Spaces:
Sleeping
Sleeping
updated improved version
Browse files- gaia_tools/__init__.py +24 -0
- gaia_tools/__pycache__/__init__.cpython-312.pyc +0 -0
- gaia_tools/__pycache__/__init__.cpython-313.pyc +0 -0
- gaia_tools/__pycache__/__init__.cpython-314.pyc +0 -0
- gaia_tools/__pycache__/code_executor.cpython-312.pyc +0 -0
- gaia_tools/__pycache__/code_executor.cpython-313.pyc +0 -0
- gaia_tools/__pycache__/code_executor.cpython-314.pyc +0 -0
- gaia_tools/__pycache__/dataset.cpython-312.pyc +0 -0
- gaia_tools/__pycache__/dataset.cpython-313.pyc +0 -0
- gaia_tools/__pycache__/dataset.cpython-314.pyc +0 -0
- gaia_tools/__pycache__/error_analysis.cpython-312.pyc +0 -0
- gaia_tools/__pycache__/error_analysis.cpython-313.pyc +0 -0
- gaia_tools/__pycache__/error_analysis.cpython-314.pyc +0 -0
- gaia_tools/__pycache__/multimodal.cpython-312.pyc +0 -0
- gaia_tools/code_executor.py +389 -0
- gaia_tools/dataset.py +160 -0
- gaia_tools/error_analysis.py +480 -0
- gaia_tools/file_processor.py +274 -0
- gaia_tools/multimodal.py +458 -0
- requirements.txt +3 -1
- speed_optimized_gaia_agent.py +339 -69
gaia_tools/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GAIA Tools Package
|
| 3 |
+
Tools and utilities for analyzing and improving GAIA agent performance.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from .error_analysis import (
|
| 7 |
+
GAIATestAnalyzer,
|
| 8 |
+
QuestionType,
|
| 9 |
+
FailureMode,
|
| 10 |
+
TestResult
|
| 11 |
+
)
|
| 12 |
+
from .dataset import (
|
| 13 |
+
GAIADatasetManager,
|
| 14 |
+
ensure_local_testing_setup
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
'GAIATestAnalyzer',
|
| 19 |
+
'QuestionType',
|
| 20 |
+
'FailureMode',
|
| 21 |
+
'TestResult',
|
| 22 |
+
'GAIADatasetManager',
|
| 23 |
+
'ensure_local_testing_setup'
|
| 24 |
+
]
|
gaia_tools/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (583 Bytes). View file
|
|
|
gaia_tools/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (583 Bytes). View file
|
|
|
gaia_tools/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (580 Bytes). View file
|
|
|
gaia_tools/__pycache__/code_executor.cpython-312.pyc
ADDED
|
Binary file (15 kB). View file
|
|
|
gaia_tools/__pycache__/code_executor.cpython-313.pyc
ADDED
|
Binary file (14.2 kB). View file
|
|
|
gaia_tools/__pycache__/code_executor.cpython-314.pyc
ADDED
|
Binary file (17.2 kB). View file
|
|
|
gaia_tools/__pycache__/dataset.cpython-312.pyc
ADDED
|
Binary file (8.22 kB). View file
|
|
|
gaia_tools/__pycache__/dataset.cpython-313.pyc
ADDED
|
Binary file (8.36 kB). View file
|
|
|
gaia_tools/__pycache__/dataset.cpython-314.pyc
ADDED
|
Binary file (9.53 kB). View file
|
|
|
gaia_tools/__pycache__/error_analysis.cpython-312.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
gaia_tools/__pycache__/error_analysis.cpython-313.pyc
ADDED
|
Binary file (20.4 kB). View file
|
|
|
gaia_tools/__pycache__/error_analysis.cpython-314.pyc
ADDED
|
Binary file (23.6 kB). View file
|
|
|
gaia_tools/__pycache__/multimodal.cpython-312.pyc
ADDED
|
Binary file (15.7 kB). View file
|
|
|
gaia_tools/code_executor.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code Execution Framework for GAIA Agent
|
| 3 |
+
|
| 4 |
+
Provides safe Python code execution for math/data processing questions.
|
| 5 |
+
Uses local execution with timeout and safety constraints.
|
| 6 |
+
|
| 7 |
+
Expected Impact: +15-20% accuracy improvement on math/calculation questions
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import re
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
import subprocess
|
| 15 |
+
import tempfile
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Optional, List
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@dataclass
class ExecutionResult:
    """Outcome of running one generated Python snippet.

    Exactly one of ``output`` / ``error`` is expected to carry information:
    ``output`` on success, ``error`` on failure.
    """
    success: bool              # True when the snippet ran and exited cleanly
    output: Optional[str]      # captured stdout on success, else None
    error: Optional[str]       # failure reason / stderr text, else None
    execution_time: float      # wall-clock seconds spent executing
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def should_use_code_execution(question: str) -> bool:
    """
    Determine if a question would benefit from code execution.

    Heuristics, in order:
      1. Research/lookup questions ("who", "wikipedia", ...) are excluded,
         unless they embed an explicit digit-operator-digit expression
         (e.g. "Who scored 25 + 30 points?").
      2. Math verbs/nouns ("calculate", "sum", ...) matched on word
         boundaries qualify, as does any digit-operator-digit expression.
      3. References to data supplied with the task (CSV, spreadsheet, ...)
         qualify.

    Args:
        question: The question text

    Returns:
        True if code execution should be used
    """
    question_lower = question.lower()

    # Symbol operators only count when they join two numbers. The previous
    # substring checks (`'x' in question`, `'-' in question`) fired on
    # ordinary words like "Texas" or "state-of-the-art".
    has_math_operators = bool(
        re.search(r'\d\s*[+\-*/%^=x×]\s*\d', question_lower)
    )

    # EXCLUSIONS: research questions that should NOT use code
    research_indicators = [
        'who', 'when', 'where', 'which person', 'which company',
        'published by', 'written by', 'created by', 'founded by',
        'according to', 'wikipedia', 'article', 'biography',
        'history of', 'year of', 'born in', 'died in'
    ]

    # If it's clearly a research/lookup question, don't use code —
    # unless it carries actual arithmetic to compute.
    if any(indicator in question_lower for indicator in research_indicators):
        if not has_math_operators:
            return False

    # Math keywords, word-bounded so "add" does not match inside "address"
    # and "mean" does not match inside "meantime".
    math_keyword_re = (
        r'\b(calculate|compute|sum|average|mean|median|multiply|divide|'
        r'subtract|add|total|square root|power|factorial|prime)\b'
    )
    if has_math_operators or re.search(math_keyword_re, question_lower):
        return True

    # Data processing keywords - only for provided data
    data_processing_indicators = [
        'from the csv', 'in the file', 'in the spreadsheet',
        'from the table', 'in the data', 'given the values',
        'calculate from', 'based on the following'
    ]
    if any(indicator in question_lower for indicator in data_processing_indicators):
        return True

    # The digit-anchored operator test above already covers the old
    # "two numbers plus any operator character" fallback.
    return False
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class CodeExecutor:
    """
    Safe Python code executor with timeout and safety constraints.

    Code is generated (via an LLM when a client is supplied, otherwise via a
    small heuristic fallback), screened for dangerous constructs, then run in
    a subprocess so a crash or hang cannot take down the agent.
    """

    # Modules that generated code is allowed to import.
    SAFE_MODULES = {'math', 'statistics', 'random', 'datetime'}

    # Call-like constructs that are never allowed in generated code.
    BLOCKED_CALLS = {
        'open(': 'file operations',
        '__import__': '__import__ function',
        'eval(': 'eval function',
        'exec(': 'exec function',
        'compile(': 'compile function',
    }

    def __init__(self, timeout: int = 10, openrouter_client=None, model: str = "x-ai/grok-4.1-fast"):
        """
        Initialize code executor.

        Args:
            timeout: Maximum execution time in seconds
            openrouter_client: OpenAI client for OpenRouter (for code generation)
            model: Model to use for code generation
        """
        self.timeout = timeout
        self.openrouter_client = openrouter_client
        self.model = model

    def generate_code(self, question: str, context: Optional[str] = None) -> str:
        """
        Generate Python code to answer the question.

        Args:
            question: The question to solve
            context: Optional context/data for the question

        Returns:
            Python code as string
        """
        # Prefer LLM generation when a client is available.
        if self.openrouter_client:
            return self._generate_code_with_llm(question, context)

        # Fallback: simple heuristic code generation for basic math.
        return self._generate_code_simple(question)

    def _generate_code_with_llm(self, question: str, context: Optional[str] = None) -> str:
        """Generate code using the LLM; falls back to heuristics on error."""
        prompt = f"""Generate Python code to answer this question. Output ONLY the Python code, no explanations.
The code must print the final answer using print().

Question: {question}"""

        if context:
            prompt += f"\n\nContext/Data: {context}"

        prompt += """

Requirements:
1. Use only Python standard library (math, statistics, etc.)
2. Print the final answer
3. Keep it simple and direct
4. No external imports except math, statistics
5. Handle edge cases

Code:"""

        try:
            response = self.openrouter_client.chat.completions.create(
                model=self.model,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=500,
                temperature=0.1
            )

            code = response.choices[0].message.content.strip()

            # Strip markdown fences if the model wrapped the code.
            if "```python" in code:
                code = code.split("```python")[1].split("```")[0].strip()
            elif "```" in code:
                code = code.split("```")[1].split("```")[0].strip()

            return code

        except Exception as e:
            print(f"❌ LLM code generation failed: {e}")
            return self._generate_code_simple(question)

    def _generate_code_simple(self, question: str) -> str:
        """
        Generate simple code without LLM (fallback).

        Handles plain arithmetic expressions, square roots, and averages.
        """
        # Try to extract a math expression by removing common filler words.
        expr = question.lower()
        for word in ['what is', 'calculate', 'compute', 'the result of', '?', 'equal', 'equals']:
            expr = expr.replace(word, ' ')

        expr = expr.strip()

        # Convert word operations to symbols.
        replacements = {
            ' plus ': '+',
            ' minus ': '-',
            ' times ': '*',
            ' divided by ': '/',
            ' multiply ': '*',
            ' divide ': '/',
            ' add ': '+',
            ' subtract ': '-'
        }
        for word, symbol in replacements.items():
            expr = expr.replace(word, symbol)

        # Collapse whitespace so the validation regex is simple.
        expr = re.sub(r'\s+', '', expr)

        # Only emit the expression if it is purely arithmetic characters.
        if re.match(r'^[\d+\-*/().\s]+$', expr):
            return f"result = {expr}\nprint(int(result) if result == int(result) else result)"

        # Fallback for square root questions.
        if 'square root' in question.lower():
            match = re.search(r'\d+', question)
            if match:
                num = match.group()
                return f"import math\nresult = math.sqrt({num})\nprint(int(result) if result == int(result) else result)"

        # Fallback for average/mean questions.
        if 'average' in question.lower() or 'mean' in question.lower():
            numbers = re.findall(r'\d+', question)
            if numbers:
                numbers_list = [int(n) for n in numbers]
                return f"values = {numbers_list}\nresult = sum(values) / len(values)\nprint(int(result) if result == int(result) else result)"

        # Default fallback
        return "print('Unable to generate code for this question')"

    def _security_error(self, code: str) -> Optional[str]:
        """Return a reason to reject *code*, or None if it looks safe.

        Fixes the original allow-list bug: a dangerous import (e.g.
        ``import os``) was accepted whenever any safe import such as
        ``import math`` also appeared anywhere in the code. Here every
        import statement must individually name a whitelisted module.
        """
        code_lower = code.lower()

        # Each import/from statement must target a whitelisted module.
        for module in re.findall(r'^\s*(?:import|from)\s+([a-z_][a-z0-9_]*)',
                                 code_lower, flags=re.MULTILINE):
            if module not in self.SAFE_MODULES:
                return f"Security: import of '{module}' is not allowed"

        # File access and dynamic-execution builtins are blocked outright.
        for token, name in self.BLOCKED_CALLS.items():
            if token in code_lower:
                return f"Security: {name} is not allowed"

        return None

    def execute(self, code: str) -> "ExecutionResult":
        """
        Execute Python code safely with timeout.

        Args:
            code: Python code to execute

        Returns:
            ExecutionResult with output or error
        """
        start_time = time.time()

        # Security screen before anything touches the filesystem.
        reason = self._security_error(code)
        if reason:
            return ExecutionResult(
                success=False,
                output=None,
                error=reason,
                execution_time=time.time() - start_time
            )

        code_file = None
        try:
            # Write the snippet to a temp file so it runs as its own process.
            with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False) as f:
                f.write(code)
                code_file = f.name

            try:
                result = subprocess.run(
                    [sys.executable, code_file],
                    capture_output=True,
                    text=True,
                    timeout=self.timeout,
                    env={**os.environ, 'PYTHONPATH': str(Path(__file__).parent)}
                )
            except subprocess.TimeoutExpired:
                return ExecutionResult(
                    success=False,
                    output=None,
                    error=f"Execution timeout ({self.timeout}s)",
                    execution_time=self.timeout
                )

            execution_time = time.time() - start_time

            if result.returncode == 0:
                return ExecutionResult(
                    success=True,
                    output=result.stdout.strip(),
                    error=None,
                    execution_time=execution_time
                )
            return ExecutionResult(
                success=False,
                output=None,
                error=result.stderr.strip(),
                execution_time=execution_time
            )

        except Exception as e:
            return ExecutionResult(
                success=False,
                output=None,
                error=str(e),
                execution_time=time.time() - start_time
            )

        finally:
            # Best-effort cleanup of the temp file; narrow except instead of
            # the original bare `except:` which also swallowed KeyboardInterrupt.
            if code_file is not None:
                try:
                    os.unlink(code_file)
                except OSError:
                    pass

    def solve_question(self, question: str, context: Optional[str] = None) -> Optional[str]:
        """
        Complete workflow: generate code, execute, return answer.

        Args:
            question: Question to solve
            context: Optional context

        Returns:
            Answer string or None if failed
        """
        print(f"   🧮 CODE EXECUTION: {question[:60]}...")

        # Generate code
        code = self.generate_code(question, context)
        print(f"   📝 Generated code ({len(code)} chars)")

        # Execute code
        result = self.execute(code)

        if result.success and result.output:
            print(f"   ✅ Execution successful: {result.output}")
            return result.output
        print(f"   ❌ Execution failed: {result.error}")
        return None
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
if __name__ == "__main__":
|
| 366 |
+
# Test the code executor
|
| 367 |
+
print("=" * 60)
|
| 368 |
+
print("Code Executor Test")
|
| 369 |
+
print("=" * 60)
|
| 370 |
+
|
| 371 |
+
executor = CodeExecutor()
|
| 372 |
+
|
| 373 |
+
# Test 1: Simple arithmetic
|
| 374 |
+
question1 = "What is 123 * 456?"
|
| 375 |
+
print(f"\nTest 1: {question1}")
|
| 376 |
+
answer1 = executor.solve_question(question1)
|
| 377 |
+
print(f"Answer: {answer1}")
|
| 378 |
+
|
| 379 |
+
# Test 2: Average
|
| 380 |
+
question2 = "What is the average of 10, 20, 30, 40, 50?"
|
| 381 |
+
print(f"\nTest 2: {question2}")
|
| 382 |
+
answer2 = executor.solve_question(question2)
|
| 383 |
+
print(f"Answer: {answer2}")
|
| 384 |
+
|
| 385 |
+
# Test 3: Square root
|
| 386 |
+
question3 = "What is the square root of 144?"
|
| 387 |
+
print(f"\nTest 3: {question3}")
|
| 388 |
+
answer3 = executor.solve_question(question3)
|
| 389 |
+
print(f"Answer: {answer3}")
|
gaia_tools/dataset.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GAIA Dataset Utilities
|
| 3 |
+
Download and cache GAIA questions for local testing
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import json
|
| 8 |
+
import requests
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Dict, Any, Optional
|
| 11 |
+
from datetime import datetime
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class GAIADatasetManager:
    """Manages GAIA dataset download and local caching.

    Questions are fetched once from the scoring API and cached as JSON so
    subsequent runs can work fully offline.
    """

    def __init__(self, cache_dir: str = "gaia_data"):
        """
        Args:
            cache_dir: Directory used to store cached questions and results.
        """
        self.cache_dir = Path(cache_dir)
        # parents=True so a nested cache path (e.g. "out/gaia_data") works;
        # the original plain mkdir raised FileNotFoundError in that case.
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.api_url = "https://agents-course-unit4-scoring.hf.space"
        self.questions_url = f"{self.api_url}/questions"
        self.submit_url = f"{self.api_url}/submit"

        self.questions_cache_file = self.cache_dir / "questions.json"
        self.metadata_file = self.cache_dir / "metadata.json"

    def download_questions(self, force_refresh: bool = False) -> List[Dict[str, Any]]:
        """
        Download GAIA questions from scoring API.

        Args:
            force_refresh: If True, always download fresh data. If False, use cache if available.

        Returns:
            List of question dictionaries

        Raises:
            requests.exceptions.RequestException: If the download fails and no
                cached copy exists to fall back on.
            ValueError: If the API returns an empty question list.
        """
        # Check cache first
        if not force_refresh and self.questions_cache_file.exists():
            print(f"📦 Loading questions from cache: {self.questions_cache_file}")
            with open(self.questions_cache_file, 'r', encoding='utf-8') as f:
                return json.load(f)

        # Download from API
        print(f"🌐 Downloading questions from: {self.questions_url}")
        try:
            response = requests.get(self.questions_url, timeout=30)
            response.raise_for_status()
            questions = response.json()

            if not questions:
                raise ValueError("Fetched questions list is empty")

            # Cache the questions
            with open(self.questions_cache_file, 'w', encoding='utf-8') as f:
                json.dump(questions, f, indent=2)

            # Update metadata
            metadata = {
                "download_time": datetime.now().isoformat(),
                "question_count": len(questions),
                "api_url": self.questions_url
            }
            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(metadata, f, indent=2)

            print(f"✅ Downloaded and cached {len(questions)} questions")
            return questions

        except requests.exceptions.RequestException as e:
            print(f"❌ Error downloading questions: {e}")

            # Fallback to cache if available
            if self.questions_cache_file.exists():
                print("📦 Falling back to cached questions")
                with open(self.questions_cache_file, 'r', encoding='utf-8') as f:
                    return json.load(f)
            # Bare raise re-raises the in-flight exception idiomatically
            # (the original used `raise e`).
            raise

    def get_cached_metadata(self) -> Optional[Dict[str, Any]]:
        """Get metadata about cached questions, or None if never cached."""
        if self.metadata_file.exists():
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return None

    def save_results(self, results: List[Dict[str, Any]], filename: Optional[str] = None):
        """
        Save test results to a file.

        Args:
            results: List of result dictionaries
            filename: Optional filename. If not provided, uses timestamp.

        Returns:
            Path of the written file.
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"results_{timestamp}.json"

        filepath = self.cache_dir / filename
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)

        print(f"💾 Results saved to: {filepath}")
        return filepath

    def load_dotenv(self):
        """Load KEY=VALUE pairs from a local .env file into os.environ.

        Blank lines and '#' comments are skipped. Surrounding single or
        double quotes around values are stripped (standard dotenv behavior;
        the original stored the quote characters literally).
        """
        env_file = Path(".env")
        if env_file.exists():
            print("📄 Loading environment variables from .env")
            with open(env_file, 'r') as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith('#') and '=' in line:
                        key, value = line.split('=', 1)
                        value = value.strip().strip('"').strip("'")
                        os.environ[key.strip()] = value
            print("✅ Environment variables loaded")
        else:
            print("⚠️ No .env file found")
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def ensure_local_testing_setup() -> GAIADatasetManager:
    """
    Ensure environment is set up for 100% local testing.

    Loads .env variables first (so later steps can use API keys), then
    downloads and caches the GAIA questions.

    Returns:
        GAIADatasetManager instance with questions cached

    Raises:
        Exception: Propagated unchanged if the question download fails.
    """
    print("🔧 Setting up for local testing...")

    # Load environment variables
    manager = GAIADatasetManager()
    manager.load_dotenv()

    # Download and cache questions (uses the cache when already present)
    try:
        questions = manager.download_questions()
        print(f"✅ Local testing setup complete ({len(questions)} questions cached)")
    except Exception as e:
        print(f"❌ Failed to download questions: {e}")
        # Bare raise re-raises the in-flight exception idiomatically
        # (the original used `raise e`).
        raise

    return manager
| 145 |
+
|
| 146 |
+
|
| 147 |
+
if __name__ == "__main__":
|
| 148 |
+
# Test the dataset manager
|
| 149 |
+
print("=" * 60)
|
| 150 |
+
print("GAIA Dataset Manager Test")
|
| 151 |
+
print("=" * 60)
|
| 152 |
+
|
| 153 |
+
manager = ensure_local_testing_setup()
|
| 154 |
+
|
| 155 |
+
# Show cache metadata
|
| 156 |
+
metadata = manager.get_cached_metadata()
|
| 157 |
+
if metadata:
|
| 158 |
+
print("\n📊 Cache Metadata:")
|
| 159 |
+
for key, value in metadata.items():
|
| 160 |
+
print(f" {key}: {value}")
|
gaia_tools/error_analysis.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GAIA Error Analysis Framework
|
| 3 |
+
|
| 4 |
+
Categorizes questions, failure modes, and generates actionable improvement recommendations.
|
| 5 |
+
Implements TDD test suite specifications from tests/test_error_analysis.py
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import csv
|
| 9 |
+
import json
|
| 10 |
+
import re
|
| 11 |
+
from dataclasses import dataclass, asdict
|
| 12 |
+
from enum import Enum
|
| 13 |
+
from typing import List, Dict, Optional, Any
|
| 14 |
+
from collections import defaultdict, Counter
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class QuestionType(Enum):
    """Categories of GAIA questions, as assigned by the classifier."""

    MATH = "math"              # arithmetic / calculation questions
    FILE = "file"              # questions about an attached document/CSV/table
    WEB = "web"                # web lookup / current-facts questions
    IMAGE = "image"            # questions about an attached image
    AUDIO = "audio"            # questions about an audio recording
    REASONING = "reasoning"    # comparative / logical reasoning
    MULTIMODAL = "multimodal"  # combines two or more input modalities
    UNKNOWN = "unknown"        # no classification pattern matched
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class FailureMode(Enum):
    """Categories of answer failures recorded for a test result."""

    WRONG_ANSWER = "wrong_answer"          # an answer was given, but incorrect
    FORMATTING_ERROR = "formatting_error"  # content plausible, format wrong
    TIMEOUT = "timeout"                    # the agent ran out of time
    TOOL_FAILURE = "tool_failure"          # a tool invocation errored
    EMPTY_RESPONSE = "empty_response"      # no answer was produced at all
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
class TestResult:
    """A single GAIA test outcome, as recorded by the analyzer.

    ``failure_mode`` and ``error`` are only meaningful when ``success`` is
    False; ``tools_used`` is normalized to a list in ``__post_init__``.
    """

    question_id: str
    question: str
    question_type: QuestionType
    expected: str
    actual: str
    success: bool
    failure_mode: Optional[FailureMode] = None
    time_elapsed: float = 0.0
    tools_used: Optional[List[str]] = None
    error: Optional[Exception] = None

    def __post_init__(self):
        # Normalize a missing tool list to an empty one so callers can
        # always iterate or append without a None check.
        if self.tools_used is None:
            self.tools_used = []
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class GAIATestAnalyzer:
    """
    Analyzes GAIA agent test results to identify failure patterns and recommend improvements.

    This class implements error categorization, performance tracking, and reporting
    to guide agent optimization efforts.
    """

    def __init__(self) -> None:
        # All logged TestResult objects, in insertion order.
        self.results: List[TestResult] = []

        # Patterns for question classification.
        # NOTE: classify_question_type() checks pattern groups in a fixed
        # priority order (multimodal > image > audio > file > math >
        # reasoning > web), so overlaps between groups are resolved by that
        # order, not by these lists themselves.
        self.math_patterns = [
            r'\d+\s*[\+\-\*\/]\s*\d+',  # Arithmetic operations with numbers
            r'calculate|compute|sum|multiply|divide|subtract|add',
            r'what is \d+',
            r'how many|how much'
        ]

        self.file_patterns = [
            r'pdf|csv|excel|spreadsheet|document|table|file',
            r'attached|according to the',
        ]

        self.image_patterns = [
            r'image|picture|photo|screenshot|attached.*color|in the (attached )?image'
        ]

        self.audio_patterns = [
            r'audio|recording|sound|said in|spoken|voice'
        ]

        self.web_patterns = [
            r'who is|what is the (current|latest)|CEO|president|founded|website',
            r'according to.*wikipedia|look up'
        ]

        self.reasoning_patterns = [
            r'if .+ then|taller than|shorter than|before|after',
            r'who is the (tallest|shortest|oldest|youngest)',
        ]

        self.multimodal_patterns = [
            r'(image|picture|photo).*(csv|file|data|spreadsheet)',
            r'(csv|file|data|spreadsheet).*(image|picture|photo)',
            r'using the .+ and the'
        ]

    def classify_question_type(self, question: str) -> QuestionType:
        """
        Classify a question into a QuestionType based on its content.

        Args:
            question: The question text to classify

        Returns:
            QuestionType enum value
        """
        # Matching is doubly case-insensitive: the text is lower-cased AND
        # re.IGNORECASE is passed (harmless redundancy, kept for safety).
        question_lower = question.lower()

        # Check multimodal first (highest priority)
        if any(re.search(pattern, question_lower, re.IGNORECASE)
               for pattern in self.multimodal_patterns):
            return QuestionType.MULTIMODAL

        # Check for image questions
        if any(re.search(pattern, question_lower, re.IGNORECASE)
               for pattern in self.image_patterns):
            return QuestionType.IMAGE

        # Check for audio questions
        if any(re.search(pattern, question_lower, re.IGNORECASE)
               for pattern in self.audio_patterns):
            return QuestionType.AUDIO

        # Check for file questions
        if any(re.search(pattern, question_lower, re.IGNORECASE)
               for pattern in self.file_patterns):
            return QuestionType.FILE

        # Check for math questions
        if any(re.search(pattern, question_lower, re.IGNORECASE)
               for pattern in self.math_patterns):
            return QuestionType.MATH

        # Check for reasoning questions
        if any(re.search(pattern, question_lower, re.IGNORECASE)
               for pattern in self.reasoning_patterns):
            return QuestionType.REASONING

        # Check for web search questions
        if any(re.search(pattern, question_lower, re.IGNORECASE)
               for pattern in self.web_patterns):
            return QuestionType.WEB

        return QuestionType.UNKNOWN

    def classify_failure_mode(
        self,
        expected: str,
        actual: Optional[str],
        error: Optional[Exception] = None
    ) -> FailureMode:
        """
        Classify why an answer failed.

        Args:
            expected: The correct answer
            actual: The agent's answer (None if error occurred)
            error: Exception if one occurred

        Returns:
            FailureMode enum value
        """
        # Check for exceptions first
        if error is not None:
            if isinstance(error, TimeoutError):
                return FailureMode.TIMEOUT
            else:
                return FailureMode.TOOL_FAILURE

        # Check for empty/unable responses
        if actual is None or actual.strip() == "":
            return FailureMode.EMPTY_RESPONSE

        if "unable to determine" in actual.lower():
            return FailureMode.EMPTY_RESPONSE

        # Check for formatting errors.  The checks below are a fallthrough
        # cascade on normalized (stripped, lower-cased) strings; their order
        # matters — do not reorder without re-validating the classification.
        expected_clean = expected.strip().lower()
        actual_clean = actual.strip().lower()

        # Remove commas and check if answers match
        expected_no_comma = expected_clean.replace(',', '')
        actual_no_comma = actual_clean.replace(',', '')
        if expected_no_comma == actual_no_comma and expected_clean != actual_clean:
            return FailureMode.FORMATTING_ERROR

        # Check for unwanted units
        # NOTE(review): if `expected` is empty, startswith("") is always True
        # and any non-empty answer is classed as a formatting error — assumes
        # the ground-truth answer is never empty; confirm against the dataset.
        if actual_clean.startswith(expected_clean):
            remainder = actual_clean[len(expected_clean):].strip()
            if remainder:  # Has extra content (likely units)
                return FailureMode.FORMATTING_ERROR

        # Check for articles (the, a, an)
        articles = ['the ', 'a ', 'an ']
        for article in articles:
            if actual_clean.startswith(article):
                without_article = actual_clean[len(article):]
                if without_article == expected_clean:
                    return FailureMode.FORMATTING_ERROR

        # If none of the above, it's a wrong answer
        return FailureMode.WRONG_ANSWER

    def log_result(self, result: TestResult) -> None:
        """
        Add a test result to the analyzer.

        Args:
            result: TestResult object to log
        """
        self.results.append(result)

    def analyze_response(
        self,
        question_id: str,
        question: str,
        expected: str,
        actual: str,
        time_elapsed: float = 0.0,
        tools_used: Optional[List[str]] = None,
        error: Optional[Exception] = None
    ) -> TestResult:
        """
        Analyze a single agent response and create a TestResult.

        This is a convenience method that combines classification and logging.

        Args:
            question_id: Unique identifier for the question
            question: The question text
            expected: The correct answer
            actual: The agent's answer
            time_elapsed: Time taken to answer
            tools_used: List of tools used by the agent
            error: Exception if one occurred

        Returns:
            TestResult object with all classifications
        """
        question_type = self.classify_question_type(question)
        # Success is a case-sensitive exact string match; near-misses are
        # categorized (commas/units/articles) by classify_failure_mode below.
        success = (actual == expected) if actual is not None else False

        failure_mode = None
        if not success:
            failure_mode = self.classify_failure_mode(expected, actual, error)

        result = TestResult(
            question_id=question_id,
            question=question,
            question_type=question_type,
            expected=expected,
            actual=actual,
            success=success,
            failure_mode=failure_mode,
            time_elapsed=time_elapsed,
            tools_used=tools_used or [],
            error=error
        )

        self.log_result(result)
        return result

    def generate_summary(self) -> Dict[str, Any]:
        """
        Generate summary statistics for all logged results.

        Returns:
            Dictionary with summary statistics
        """
        if not self.results:
            return {
                "total_questions": 0,
                "correct_count": 0,
                "accuracy": 0.0,
                "avg_time": 0.0
            }

        total = len(self.results)
        correct = sum(1 for r in self.results if r.success)
        total_time = sum(r.time_elapsed for r in self.results)

        return {
            "total_questions": total,
            "correct_count": correct,
            "accuracy": correct / total if total > 0 else 0.0,
            "avg_time": total_time / total if total > 0 else 0.0
        }

    def get_accuracy_by_type(self) -> Dict[QuestionType, float]:
        """
        Calculate accuracy broken down by question type.

        Returns:
            Dictionary mapping QuestionType to accuracy (0.0-1.0)
        """
        # Only question types that actually occur in results appear as keys.
        type_stats = defaultdict(lambda: {"correct": 0, "total": 0})

        for result in self.results:
            stats = type_stats[result.question_type]
            stats["total"] += 1
            if result.success:
                stats["correct"] += 1

        accuracy_by_type = {}
        for qtype, stats in type_stats.items():
            accuracy_by_type[qtype] = (
                stats["correct"] / stats["total"] if stats["total"] > 0 else 0.0
            )

        return accuracy_by_type

    def get_failures_by_mode(self) -> Dict[FailureMode, int]:
        """
        Count failures by failure mode.

        Returns:
            Dictionary mapping FailureMode to count
        """
        failure_counts = Counter()

        for result in self.results:
            if not result.success and result.failure_mode:
                failure_counts[result.failure_mode] += 1

        return dict(failure_counts)

    def export_to_csv(self, filepath: str) -> None:
        """
        Export all results to a CSV file.

        Args:
            filepath: Path to output CSV file
        """
        # NOTE: the `error` field of TestResult is intentionally not exported
        # (exceptions do not round-trip through CSV); enum values are written
        # upper-cased.
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)

            # Write header
            writer.writerow([
                'question_id', 'question', 'question_type', 'expected', 'actual',
                'success', 'failure_mode', 'time_elapsed', 'tools_used'
            ])

            # Write results
            for result in self.results:
                writer.writerow([
                    result.question_id,
                    result.question,
                    result.question_type.value.upper(),
                    result.expected,
                    result.actual,
                    result.success,
                    result.failure_mode.value.upper() if result.failure_mode else '',
                    result.time_elapsed,
                    ','.join(result.tools_used) if result.tools_used else ''
                ])

    def export_to_json(self, filepath: str) -> None:
        """
        Export all results and summary to a JSON file.

        Args:
            filepath: Path to output JSON file
        """
        # Enum members are serialized as their .value strings; the `error`
        # field is omitted (exceptions are not JSON-serializable).
        data = {
            "summary": self.generate_summary(),
            "accuracy_by_type": {
                qtype.value: acc
                for qtype, acc in self.get_accuracy_by_type().items()
            },
            "failures_by_mode": {
                mode.value: count
                for mode, count in self.get_failures_by_mode().items()
            },
            "results": [
                {
                    "question_id": r.question_id,
                    "question": r.question,
                    "question_type": r.question_type.value,
                    "expected": r.expected,
                    "actual": r.actual,
                    "success": r.success,
                    "failure_mode": r.failure_mode.value if r.failure_mode else None,
                    "time_elapsed": r.time_elapsed,
                    "tools_used": r.tools_used
                }
                for r in self.results
            ]
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def get_recommendations(self) -> List[str]:
        """
        Generate actionable recommendations based on failure analysis.

        Returns:
            List of recommendation strings
        """
        # Thresholds below (0.5 / 0.7 accuracy, 10% / 5% failure-rate) are
        # heuristics chosen for this benchmark, not derived constants.
        recommendations = []

        # Analyze question types with low accuracy
        accuracy_by_type = self.get_accuracy_by_type()
        failures_by_mode = self.get_failures_by_mode()

        # Check for image-related failures
        image_results = [r for r in self.results if r.question_type == QuestionType.IMAGE]
        if image_results:
            image_accuracy = accuracy_by_type.get(QuestionType.IMAGE, 0.0)
            if image_accuracy < 0.5:
                recommendations.append(
                    "Add vision capabilities (Gemini 2.5 Pro) to handle image questions"
                )

        # Check for file processing failures
        file_results = [r for r in self.results if r.question_type == QuestionType.FILE]
        if file_results:
            file_accuracy = accuracy_by_type.get(QuestionType.FILE, 0.0)
            if file_accuracy < 0.5:
                recommendations.append(
                    "Implement file processing capabilities (PDF/CSV/Excel parsing)"
                )

        # Check for math failures
        math_results = [r for r in self.results if r.question_type == QuestionType.MATH]
        if math_results:
            math_accuracy = accuracy_by_type.get(QuestionType.MATH, 0.0)
            if math_accuracy < 0.7:
                recommendations.append(
                    "Add code execution capabilities for reliable math calculations"
                )

        # Check for formatting errors
        formatting_errors = failures_by_mode.get(FailureMode.FORMATTING_ERROR, 0)
        if formatting_errors > len(self.results) * 0.1:  # More than 10% formatting errors
            recommendations.append(
                "Improve answer formatting logic to handle commas, units, and articles"
            )

        # Check for empty responses
        empty_responses = failures_by_mode.get(FailureMode.EMPTY_RESPONSE, 0)
        if empty_responses > len(self.results) * 0.1:
            recommendations.append(
                "Improve tool reliability and add fallback mechanisms for empty responses"
            )

        # Check for timeouts
        timeouts = failures_by_mode.get(FailureMode.TIMEOUT, 0)
        if timeouts > len(self.results) * 0.05:
            recommendations.append(
                "Optimize query speed and increase timeout thresholds for complex questions"
            )

        # Check for audio processing
        audio_results = [r for r in self.results if r.question_type == QuestionType.AUDIO]
        if audio_results:
            audio_accuracy = accuracy_by_type.get(QuestionType.AUDIO, 0.0)
            if audio_accuracy < 0.5:
                recommendations.append(
                    "Add audio transcription capabilities (Whisper)"
                )

        # Check for multimodal questions
        multimodal_results = [r for r in self.results if r.question_type == QuestionType.MULTIMODAL]
        if multimodal_results:
            multimodal_accuracy = accuracy_by_type.get(QuestionType.MULTIMODAL, 0.0)
            if multimodal_accuracy < 0.5:
                recommendations.append(
                    "Improve multimodal reasoning by integrating multiple tool outputs"
                )

        return recommendations
|
gaia_tools/file_processor.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
File Processing Framework for GAIA Agent
|
| 3 |
+
|
| 4 |
+
Handles PDF, CSV, Excel, images, and audio files for GAIA questions.
|
| 5 |
+
|
| 6 |
+
Expected Impact: +10-15% accuracy improvement on file-based questions
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import re
|
| 10 |
+
import os
|
| 11 |
+
import io
|
| 12 |
+
from typing import Optional, Dict, Any, List
|
| 13 |
+
from dataclasses import dataclass
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
import tempfile
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class ProcessedFile:
    """Result of file processing"""
    success: bool                # True when content extraction succeeded
    file_type: str               # e.g. 'pdf', 'csv', 'excel', 'image', 'audio'
    content: Optional[str]       # extracted/summarized text; None on failure
    metadata: Dict[str, Any]     # format-specific details (rows, sheets, ...)
    error: Optional[str] = None  # human-readable failure reason when success is False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def extract_file_references(question: str) -> List[str]:
    """
    Extract file references from a question.

    Args:
        question: Question text

    Returns:
        List of file references/URLs found (deduplicated; order not guaranteed)
    """
    # BUG FIX: re.findall() returns only the *capture groups* — tuples for
    # multi-group patterns — so the previous implementation returned fragments
    # like ('attached', 'PDF') or bare 'pdf' instead of the referenced text.
    # Patterns now use non-capturing groups and we collect the full match.
    file_patterns = [
        # Phrases such as "the attached PDF" / "the CSV"
        r'(?:attached|the)\s+(?:PDF|CSV|Excel|spreadsheet|image|picture|photo|audio|file)',
        # Bare file names such as "report.pdf"
        r'\b[\w.-]*\.(?:pdf|csv|xlsx|xls|png|jpg|jpeg|gif|mp3|wav|m4a)\b',
        # Direct URLs to files
        r'https?://\S+\.(?:pdf|csv|xlsx|png|jpg|jpeg)'
    ]

    references = []
    for pattern in file_patterns:
        for match in re.finditer(pattern, question, re.IGNORECASE):
            references.append(match.group(0))

    # Deduplicate (set iteration order is unspecified, matching the original).
    return list(set(references))
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def should_use_file_processing(question: str) -> bool:
    """Determine if question requires file processing"""
    # Keywords that signal an attached document/media the agent must read.
    trigger_words = (
        'attached', 'pdf', 'csv', 'excel', 'spreadsheet',
        'image', 'picture', 'photo', 'document', 'file',
        'table', 'according to the',
    )

    lowered = question.lower()
    for word in trigger_words:
        if word in lowered:
            return True
    return False
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class FileProcessor:
    """
    Multi-format file processor for GAIA questions.

    Supports: PDF, CSV, Excel, Images (OCR), Audio (transcription)
    """

    def __init__(self):
        # Extensions process_file() can dispatch.  FIX: 'm4a' was dispatched
        # by process_file() but missing from this list.
        self.supported_formats = [
            'pdf', 'csv', 'xlsx', 'xls', 'png', 'jpg', 'jpeg', 'gif',
            'mp3', 'wav', 'm4a'
        ]

    def process_file(self, file_path: str) -> ProcessedFile:
        """
        Process a file and extract its content.

        Args:
            file_path: Path to the file

        Returns:
            ProcessedFile with extracted content; ``success=False`` and a
            populated ``error`` for missing or unsupported files.
        """
        if not os.path.exists(file_path):
            return ProcessedFile(
                success=False,
                file_type='unknown',
                content=None,
                metadata={},
                error=f"File not found: {file_path}"
            )

        # Dispatch on the lower-cased file extension.
        ext = Path(file_path).suffix.lower().lstrip('.')

        if ext == 'pdf':
            return self._process_pdf(file_path)
        elif ext == 'csv':
            return self._process_csv(file_path)
        elif ext in ('xlsx', 'xls'):
            return self._process_excel(file_path)
        elif ext in ('png', 'jpg', 'jpeg', 'gif'):
            return self._process_image(file_path)
        elif ext in ('mp3', 'wav', 'm4a'):
            return self._process_audio(file_path)
        else:
            return ProcessedFile(
                success=False,
                file_type=ext,
                content=None,
                metadata={},
                error=f"Unsupported file type: {ext}"
            )

    def _process_pdf(self, file_path: str) -> ProcessedFile:
        """Process PDF file (no extraction backend is integrated yet).

        FIX: the previous implementation called pandas.read_html() on the PDF
        path — read_html parses HTML, not PDF, so that branch could never
        succeed — and hid the inevitable failure behind a bare ``except:``.
        The dead branch and the blanket exception swallowing are removed.
        """
        # TODO: integrate pypdf/PyPDF2 for real text extraction.
        return ProcessedFile(
            success=False,
            file_type='pdf',
            content=None,
            metadata={},
            error="PDF processing requires PyPDF2 or similar library"
        )

    def _process_csv(self, file_path: str) -> ProcessedFile:
        """Process CSV file into a textual digest (shape, head, describe())."""
        try:
            import pandas as pd

            df = pd.read_csv(file_path)

            # Build a human-readable summary for downstream reasoning.
            summary = "CSV File Summary:\n"
            summary += f"Rows: {len(df)}\n"
            summary += f"Columns: {list(df.columns)}\n\n"
            summary += f"First 10 rows:\n{df.head(10).to_string()}\n\n"
            summary += f"Statistics:\n{df.describe().to_string()}"

            return ProcessedFile(
                success=True,
                file_type='csv',
                content=summary,
                metadata={
                    'rows': len(df),
                    'columns': list(df.columns),
                    'shape': df.shape
                }
            )

        except Exception as e:
            return ProcessedFile(
                success=False,
                file_type='csv',
                content=None,
                metadata={},
                error=str(e)
            )

    def _process_excel(self, file_path: str) -> ProcessedFile:
        """Process Excel file: summarize every sheet (shape, columns, head)."""
        try:
            import pandas as pd

            # Read all sheets
            excel_file = pd.ExcelFile(file_path)
            sheets = {}

            for sheet_name in excel_file.sheet_names:
                df = pd.read_excel(file_path, sheet_name=sheet_name)
                sheets[sheet_name] = df

            # Generate summary
            summary = "Excel File Summary:\n"
            summary += f"Sheets: {list(sheets.keys())}\n\n"

            for sheet_name, df in sheets.items():
                summary += f"\n--- Sheet: {sheet_name} ---\n"
                summary += f"Rows: {len(df)}, Columns: {len(df.columns)}\n"
                summary += f"Columns: {list(df.columns)}\n"
                summary += f"First 5 rows:\n{df.head(5).to_string()}\n"

            return ProcessedFile(
                success=True,
                file_type='excel',
                content=summary,
                metadata={
                    'sheets': list(sheets.keys()),
                    'total_rows': sum(len(df) for df in sheets.values())
                }
            )

        except Exception as e:
            return ProcessedFile(
                success=False,
                file_type='excel',
                content=None,
                metadata={},
                error=str(e)
            )

    def _process_image(self, file_path: str) -> ProcessedFile:
        """Process image file (placeholder for vision API)."""
        # For now, return metadata only — vision support is planned (Phase 3).
        return ProcessedFile(
            success=False,
            file_type='image',
            content=None,
            metadata={'file_path': file_path},
            error="Image processing requires vision API (Phase 3)"
        )

    def _process_audio(self, file_path: str) -> ProcessedFile:
        """Process audio file (placeholder for transcription)."""
        # For now, return metadata only — transcription would use Whisper.
        return ProcessedFile(
            success=False,
            file_type='audio',
            content=None,
            metadata={'file_path': file_path},
            error="Audio processing requires transcription API"
        )
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
    # Smoke-test the detection helpers and the processor constructor.
    banner = "=" * 60
    print(banner)
    print("File Processor Test")
    print(banner)

    processor = FileProcessor()

    # Sample questions: three with file references, one without.
    sample_questions = [
        "According to the attached PDF, what is the total revenue?",
        "From the CSV file, how many entries have status 'completed'?",
        "What color is the car in the image?",
        "Who is the CEO of Apple?"  # No file
    ]

    for question in sample_questions:
        print(f"\nQuestion: {question}")
        print(f"Needs file processing: {should_use_file_processing(question)}")
        found_refs = extract_file_references(question)
        if found_refs:
            print(f"File references: {found_refs}")
|
gaia_tools/multimodal.py
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Multimodal Processing Framework for GAIA Agent
|
| 3 |
+
|
| 4 |
+
Handles Audio, Video, and Image processing for GAIA benchmark questions.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import tempfile
|
| 10 |
+
import requests
|
| 11 |
+
from typing import Optional, Dict, Any
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class MultimodalResult:
    """Result from multimodal processing"""
    success: bool                # True when content was extracted
    content: Optional[str]       # extracted text (e.g. a transcription); None on failure
    modality: str                # modality label, e.g. "audio"
    metadata: Dict[str, Any]     # processor-specific details (model, language, file, ...)
    error: Optional[str] = None  # human-readable failure reason when success is False
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AudioProcessor:
    """
    Process audio files using OpenAI Whisper API via OpenRouter or local.

    Strategy:
      1. If an OpenAI-style client is configured, route to the API path
         (currently a disabled stub because the Whisper API is not free).
      2. Otherwise fall back to local faster-whisper (free, CPU, int8).
    """

    # Class-level cache: the faster-whisper "base" model (~74MB) is loaded at
    # most once per process instead of on every transcribe() call.
    _whisper_model = None

    def __init__(self, openai_client=None):
        # Optional client for the (disabled) Whisper API path.
        self.client = openai_client

    def transcribe(self, audio_path: str = None, audio_url: str = None) -> MultimodalResult:
        """
        Transcribe audio file to text.

        Args:
            audio_path: Local path to audio file
            audio_url: URL to audio file (downloaded to a temp file, which is
                       removed again before returning)

        Returns:
            MultimodalResult with transcription
        """
        downloaded_path = None
        try:
            # If URL provided, download first
            if audio_url and not audio_path:
                downloaded_path = self._download_audio(audio_url)
                audio_path = downloaded_path

            if not audio_path or not os.path.exists(audio_path):
                return MultimodalResult(
                    success=False,
                    content=None,
                    modality="audio",
                    metadata={},
                    error=f"Audio file not found: {audio_path}"
                )

            # Try using OpenAI Whisper API
            if self.client:
                return self._transcribe_with_api(audio_path)

            # Fallback: Try local whisper
            return self._transcribe_local(audio_path)

        except Exception as e:
            return MultimodalResult(
                success=False,
                content=None,
                modality="audio",
                metadata={},
                error=str(e)
            )
        finally:
            # BUGFIX: remove any temp file we created from a URL download —
            # previously it was leaked on every URL-based transcription.
            if downloaded_path:
                try:
                    os.unlink(downloaded_path)
                except OSError:
                    pass

    def _download_audio(self, url: str) -> Optional[str]:
        """Download audio from URL to a temp file; return its path or None."""
        try:
            response = requests.get(url, timeout=30)
            response.raise_for_status()

            # Determine extension (default mp3; wav detected from the URL)
            ext = ".mp3"
            if ".wav" in url.lower():
                ext = ".wav"

            # delete=False: the caller needs the file after this block closes;
            # transcribe() is responsible for cleanup.
            with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as f:
                f.write(response.content)
                return f.name
        except Exception as e:
            print(f"❌ Failed to download audio: {e}")
            return None

    def _transcribe_with_api(self, audio_path: str) -> MultimodalResult:
        """Transcribe using OpenAI Whisper API (DISABLED - not free)"""
        # OpenAI Whisper API is NOT free, so we skip this
        return MultimodalResult(
            success=False,
            content=None,
            modality="audio",
            metadata={},
            error="OpenAI Whisper API disabled (not free). Use local whisper instead."
        )

    def _transcribe_local(self, audio_path: str) -> MultimodalResult:
        """Transcribe using local faster-whisper (100% free)"""
        try:
            from faster_whisper import WhisperModel

            # Use base model for better accuracy (74MB, still fast).
            # PERF: cache the loaded model on the class so repeated calls
            # skip the expensive model load.
            if AudioProcessor._whisper_model is None:
                AudioProcessor._whisper_model = WhisperModel(
                    "base", device="cpu", compute_type="int8"
                )
            model = AudioProcessor._whisper_model

            segments, info = model.transcribe(audio_path, beam_size=5)

            # Combine all segments (fully consumes the segment generator)
            full_text = " ".join(segment.text for segment in segments)

            return MultimodalResult(
                success=True,
                content=full_text,
                modality="audio",
                metadata={
                    "method": "faster-whisper",
                    "model": "base",
                    "file": audio_path,
                    "language": info.language
                }
            )
        except ImportError:
            return MultimodalResult(
                success=False,
                content=None,
                modality="audio",
                metadata={},
                error="faster-whisper not installed. Run: pip install faster-whisper"
            )
        except Exception as e:
            return MultimodalResult(
                success=False,
                content=None,
                modality="audio",
                metadata={},
                error=f"Local whisper error: {e}"
            )
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class VideoProcessor:
    """
    Process video files and YouTube links.
    Extracts transcripts/subtitles for analysis.
    """

    def __init__(self):
        # Stateless processor; nothing to initialize.
        pass

    def process(self, video_url: str = None, video_path: str = None) -> MultimodalResult:
        """
        Process video and extract transcript.

        Args:
            video_url: YouTube URL or video URL
            video_path: Local path to video file

        Returns:
            MultimodalResult with video transcript/content
        """
        try:
            # Check for YouTube URL (takes precedence over local path)
            if video_url and ("youtube.com" in video_url or "youtu.be" in video_url):
                return self._process_youtube(video_url)

            # Local video file
            if video_path:
                return self._process_local_video(video_path)

            return MultimodalResult(
                success=False,
                content=None,
                modality="video",
                metadata={},
                error="No video URL or path provided"
            )

        except Exception as e:
            return MultimodalResult(
                success=False,
                content=None,
                modality="video",
                metadata={},
                error=str(e)
            )

    def _process_youtube(self, url: str) -> MultimodalResult:
        """Extract transcript from YouTube video"""
        try:
            from youtube_transcript_api import YouTubeTranscriptApi

            # Extract video ID
            video_id = self._extract_video_id(url)
            if not video_id:
                return MultimodalResult(
                    success=False,
                    content=None,
                    modality="video",
                    metadata={},
                    error=f"Could not extract video ID from: {url}"
                )

            # Get transcript (list of {"text", "start", "duration"} dicts)
            transcript_list = YouTubeTranscriptApi.get_transcript(video_id)

            # Combine transcript segments
            full_transcript = " ".join([entry["text"] for entry in transcript_list])

            return MultimodalResult(
                success=True,
                content=full_transcript,
                modality="video",
                metadata={
                    "method": "youtube-transcript",
                    "video_id": video_id,
                    "url": url,
                    "segments": len(transcript_list)
                }
            )

        except ImportError:
            return MultimodalResult(
                success=False,
                content=None,
                modality="video",
                metadata={},
                error="youtube-transcript-api not installed. Run: pip install youtube-transcript-api"
            )
        except Exception as e:
            # Transcript unavailable (disabled captions, rate-limit, ...) —
            # try fallback method
            return self._youtube_fallback(url, str(e))

    def _extract_video_id(self, url: str) -> Optional[str]:
        """Extract the 11-character YouTube video ID from a URL, or None."""
        patterns = [
            r'(?:youtube\.com\/watch\?v=|youtu\.be\/|youtube\.com\/embed\/)([a-zA-Z0-9_-]{11})',
            r'youtube\.com\/watch\?.*v=([a-zA-Z0-9_-]{11})'
        ]

        for pattern in patterns:
            match = re.search(pattern, url)
            if match:
                return match.group(1)
        return None

    def _youtube_fallback(self, url: str, original_error: str) -> MultimodalResult:
        """Fallback method for YouTube when transcript API fails"""
        # Try using yt-dlp to get title/description metadata instead
        try:
            import subprocess
            result = subprocess.run(
                ["yt-dlp", "--get-title", "--get-description", url],
                capture_output=True,
                text=True,
                timeout=30
            )

            if result.returncode == 0:
                content = f"Video Title and Description:\n{result.stdout}"
                return MultimodalResult(
                    success=True,
                    content=content,
                    modality="video",
                    metadata={"method": "yt-dlp-metadata", "url": url}
                )
        except Exception:
            # BUGFIX: was a bare `except:` which also swallowed SystemExit and
            # KeyboardInterrupt; narrowed to Exception. Best-effort fallback —
            # failure here still yields the informative error result below.
            pass

        return MultimodalResult(
            success=False,
            content=None,
            modality="video",
            metadata={},
            error=f"YouTube transcript failed: {original_error}. Install: pip install youtube-transcript-api"
        )

    def _process_local_video(self, video_path: str) -> MultimodalResult:
        """Process local video file (extract audio and transcribe)"""
        # Placeholder: would require ffmpeg audio extraction + whisper.
        return MultimodalResult(
            success=False,
            content=None,
            modality="video",
            metadata={},
            error="Local video processing requires ffmpeg + whisper. Not yet implemented."
        )
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
class ImageProcessor:
    """
    Process images using vision-capable LLM.
    """

    def __init__(self, openrouter_client=None, model: str = "google/gemma-3-27b:free"):
        """
        Initialize image processor.

        Args:
            openrouter_client: OpenAI client configured for OpenRouter
            model: Vision-capable model to use
        """
        self.client = openrouter_client
        self.model = model

    def analyze(self, image_path: str = None, image_url: str = None,
                question: str = "Describe this image in detail.") -> MultimodalResult:
        """
        Analyze image and answer question about it.

        Args:
            image_path: Local path to image
            image_url: URL to image
            question: Question to answer about the image

        Returns:
            MultimodalResult with analysis
        """
        try:
            if not self.client:
                return MultimodalResult(
                    success=False,
                    content=None,
                    modality="image",
                    metadata={},
                    error="No OpenRouter client configured for vision"
                )

            # Prepare image for API (local path takes precedence over URL)
            if image_path:
                image_data = self._encode_image(image_path)
                if not image_data:
                    return MultimodalResult(
                        success=False,
                        content=None,
                        modality="image",
                        metadata={},
                        error=f"Failed to encode image: {image_path}"
                    )
                # BUGFIX: detect the real MIME type instead of hard-coding
                # image/jpeg, which mislabeled PNG/GIF/WebP files in the
                # data URL. Falls back to jpeg when the type is unknown.
                import mimetypes
                mime_type = mimetypes.guess_type(image_path)[0] or "image/jpeg"
                image_content = {
                    "type": "image_url",
                    "image_url": {"url": f"data:{mime_type};base64,{image_data}"}
                }
            elif image_url:
                image_content = {
                    "type": "image_url",
                    "image_url": {"url": image_url}
                }
            else:
                return MultimodalResult(
                    success=False,
                    content=None,
                    modality="image",
                    metadata={},
                    error="No image path or URL provided"
                )

            # Call vision model with mixed text + image content
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": question},
                            image_content
                        ]
                    }
                ],
                max_tokens=500
            )

            content = response.choices[0].message.content

            return MultimodalResult(
                success=True,
                content=content,
                modality="image",
                metadata={
                    "method": "vision-llm",
                    "model": self.model,
                    "image_source": image_path or image_url
                }
            )

        except Exception as e:
            return MultimodalResult(
                success=False,
                content=None,
                modality="image",
                metadata={},
                error=f"Vision analysis error: {e}"
            )

    def _encode_image(self, image_path: str) -> Optional[str]:
        """Encode image file contents to a base64 string, or None on failure."""
        try:
            import base64
            with open(image_path, "rb") as f:
                return base64.b64encode(f.read()).decode("utf-8")
        except Exception as e:
            print(f"❌ Failed to encode image: {e}")
            return None
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
class MultimodalProcessor:
    """
    Unified entry point for multimodal work in the GAIA agent.

    Holds one handler per modality and simply forwards each process_* call
    to the matching handler, logging which pipeline was chosen.
    """

    def __init__(self, openrouter_client=None, openai_client=None):
        """
        Build the per-modality handlers.

        Args:
            openrouter_client: Client used by the vision (image) handler
            openai_client: Optional client for the Whisper API audio path
        """
        # One handler per modality; construction order is irrelevant.
        self.image = ImageProcessor(openrouter_client)
        self.video = VideoProcessor()
        self.audio = AudioProcessor(openai_client)

    def process_audio(self, audio_path: str = None, audio_url: str = None) -> MultimodalResult:
        """Transcribe an audio file (local path or URL)."""
        print("🎵 Processing audio...")
        return self.audio.transcribe(audio_path, audio_url)

    def process_video(self, video_url: str = None, video_path: str = None) -> MultimodalResult:
        """Extract a transcript from a video file or YouTube URL."""
        print("🎬 Processing video...")
        return self.video.process(video_url, video_path)

    def process_image(self, image_path: str = None, image_url: str = None,
                      question: str = "Describe this image.") -> MultimodalResult:
        """Analyze an image and answer `question` about it."""
        print("🖼️ Processing image...")
        return self.image.analyze(image_path, image_url, question)
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
if __name__ == "__main__":
|
| 444 |
+
# Test multimodal processors
|
| 445 |
+
print("=" * 60)
|
| 446 |
+
print("Multimodal Processor Test")
|
| 447 |
+
print("=" * 60)
|
| 448 |
+
|
| 449 |
+
processor = MultimodalProcessor()
|
| 450 |
+
|
| 451 |
+
# Test YouTube processing
|
| 452 |
+
print("\n📺 Testing YouTube transcript extraction...")
|
| 453 |
+
result = processor.process_video(video_url="https://www.youtube.com/watch?v=dQw4w9WgXcQ")
|
| 454 |
+
print(f"Success: {result.success}")
|
| 455 |
+
if result.success:
|
| 456 |
+
print(f"Content preview: {result.content[:200]}...")
|
| 457 |
+
else:
|
| 458 |
+
print(f"Error: {result.error}")
|
requirements.txt
CHANGED
|
@@ -15,4 +15,6 @@ openpyxl
|
|
| 15 |
python-magic
|
| 16 |
mutagen
|
| 17 |
sentence-transformers
|
| 18 |
-
scikit-learn
|
|
|
|
|
|
|
|
|
| 15 |
python-magic
|
| 16 |
mutagen
|
| 17 |
sentence-transformers
|
| 18 |
+
scikit-learn
|
| 19 |
+
youtube-transcript-api
|
| 20 |
+
faster-whisper
|
speed_optimized_gaia_agent.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
-
Speed-Optimized GAIA Agent with
|
| 3 |
-
|
| 4 |
"""
|
| 5 |
|
| 6 |
import os
|
|
@@ -20,6 +20,22 @@ import random
|
|
| 20 |
from ddgs import DDGS
|
| 21 |
import wikipedia
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# OpenRouter integration
|
| 24 |
try:
|
| 25 |
import openai
|
|
@@ -74,21 +90,26 @@ class SpeedOptimizedGAIAAgent:
|
|
| 74 |
|
| 75 |
print(f"🔑 OpenRouter API: ✅ Available")
|
| 76 |
|
| 77 |
-
#
|
| 78 |
self.models = {
|
| 79 |
"primary": {
|
| 80 |
-
"name": "
|
| 81 |
-
"role": "Primary
|
| 82 |
"client": self._create_openrouter_client()
|
| 83 |
},
|
| 84 |
"secondary": {
|
| 85 |
-
"name": "
|
| 86 |
-
"role": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
"client": self._create_openrouter_client()
|
| 88 |
}
|
| 89 |
}
|
| 90 |
-
|
| 91 |
-
print("🤖 Using
|
| 92 |
|
| 93 |
# Initialize vector similarity if available
|
| 94 |
self.vector_cache = {}
|
|
@@ -103,7 +124,27 @@ class SpeedOptimizedGAIAAgent:
|
|
| 103 |
# Search engines (optimized order)
|
| 104 |
self.ddgs = DDGS()
|
| 105 |
self.setup_search_engines()
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
# Performance tracking
|
| 108 |
self.start_time = None
|
| 109 |
|
|
@@ -114,22 +155,33 @@ class SpeedOptimizedGAIAAgent:
|
|
| 114 |
base_url="https://openrouter.ai/api/v1"
|
| 115 |
)
|
| 116 |
|
| 117 |
-
def retry_with_backoff(self, func, *args, max_attempts=6, **kwargs):
|
| 118 |
-
"""
|
| 119 |
-
|
| 120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
for attempt in range(max_attempts):
|
| 122 |
try:
|
| 123 |
return func(*args, **kwargs)
|
| 124 |
except Exception as e:
|
| 125 |
if attempt == max_attempts - 1:
|
| 126 |
-
print(f"❌
|
| 127 |
raise e
|
| 128 |
-
|
| 129 |
delay = delay_pattern[attempt]
|
| 130 |
-
print(f"⏳
|
| 131 |
time.sleep(delay)
|
| 132 |
-
|
| 133 |
raise Exception("Max retry attempts exceeded")
|
| 134 |
|
| 135 |
def setup_search_engines(self):
|
|
@@ -222,23 +274,94 @@ class SpeedOptimizedGAIAAgent:
|
|
| 222 |
|
| 223 |
return "\n\n".join(all_results) if all_results else "No search results found"
|
| 224 |
|
| 225 |
-
def classify_question_type(self, question: str) -> str:
|
| 226 |
-
"""
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
if
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
|
| 243 |
def get_fast_response(self, model_key: str, question: str, context: str = "") -> Dict[str, Any]:
|
| 244 |
"""Get response with optimized parameters for speed and retry logic"""
|
|
@@ -246,12 +369,27 @@ class SpeedOptimizedGAIAAgent:
|
|
| 246 |
|
| 247 |
print(f"🤖 {model_key} processing...")
|
| 248 |
|
| 249 |
-
system_prompt = """You are
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
-
|
| 252 |
-
-
|
| 253 |
-
-
|
| 254 |
-
-
|
|
|
|
| 255 |
|
| 256 |
Respond with ONLY the answer, no explanation unless specifically requested."""
|
| 257 |
|
|
@@ -269,8 +407,9 @@ Respond with ONLY the answer, no explanation unless specifically requested."""
|
|
| 269 |
temperature=0.1
|
| 270 |
)
|
| 271 |
return response
|
| 272 |
-
|
| 273 |
-
|
|
|
|
| 274 |
|
| 275 |
# Enhanced error checking
|
| 276 |
if not response or not hasattr(response, 'choices') or not response.choices:
|
|
@@ -322,16 +461,16 @@ Respond with ONLY the answer, no explanation unless specifically requested."""
|
|
| 322 |
return "Unable to determine answer"
|
| 323 |
|
| 324 |
def solve_consensus(self, question: str, context: str) -> str:
|
| 325 |
-
"""Solve using
|
| 326 |
-
print("🔄 Running
|
| 327 |
-
|
| 328 |
results = []
|
| 329 |
-
with ThreadPoolExecutor(max_workers=
|
| 330 |
futures = {
|
| 331 |
-
executor.submit(self.get_fast_response, model_key, question, context): model_key
|
| 332 |
-
for model_key in ["primary", "secondary"]
|
| 333 |
}
|
| 334 |
-
|
| 335 |
# Increased timeout for HuggingFace environment
|
| 336 |
for future in as_completed(futures, timeout=30): # Increased from 15s
|
| 337 |
try:
|
|
@@ -342,32 +481,72 @@ Respond with ONLY the answer, no explanation unless specifically requested."""
|
|
| 342 |
model_key = futures[future]
|
| 343 |
print(f"❌ {model_key} error: {e}")
|
| 344 |
# Continue with other models instead of failing
|
| 345 |
-
|
| 346 |
# Enhanced consensus with fallback
|
| 347 |
valid_results = [r for r in results if r and r.get("success") and r.get("answer")]
|
| 348 |
if not valid_results:
|
| 349 |
print("❌ No valid results from any model, using fallback")
|
| 350 |
return "Unable to determine answer"
|
| 351 |
-
|
| 352 |
# If only one model succeeded, use its answer
|
| 353 |
if len(valid_results) == 1:
|
| 354 |
answer = valid_results[0]["answer"]
|
| 355 |
return self.format_gaia_answer(answer)
|
| 356 |
-
|
| 357 |
-
# Multiple models - find consensus
|
| 358 |
answers = [r["answer"] for r in valid_results]
|
| 359 |
formatted_answers = [self.format_gaia_answer(ans) for ans in answers if ans]
|
| 360 |
-
|
| 361 |
if not formatted_answers:
|
| 362 |
return "Unable to determine answer"
|
| 363 |
-
|
| 364 |
-
# Return most common answer, or first if all different
|
| 365 |
from collections import Counter
|
| 366 |
answer_counts = Counter(formatted_answers)
|
| 367 |
best_answer = answer_counts.most_common(1)[0][0]
|
| 368 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
print(f"🎯 Consensus: {best_answer} (from {len(valid_results)} models)")
|
| 370 |
return best_answer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
|
| 372 |
def format_gaia_answer(self, answer: str) -> str:
|
| 373 |
"""Fast answer formatting"""
|
|
@@ -392,24 +571,115 @@ Respond with ONLY the answer, no explanation unless specifically requested."""
|
|
| 392 |
if ".rewsna eht sa" in question:
|
| 393 |
print(f"⚡ Solved in {time.time() - self.start_time:.2f}s")
|
| 394 |
return "right"
|
| 395 |
-
|
| 396 |
# Check vector similarity cache
|
| 397 |
cached_answer = self.check_vector_similarity(question)
|
| 398 |
if cached_answer:
|
| 399 |
print(f"⚡ Cache hit in {time.time() - self.start_time:.2f}s")
|
| 400 |
return cached_answer
|
| 401 |
-
|
| 402 |
-
# Classify question
|
| 403 |
question_type = self.classify_question_type(question)
|
| 404 |
-
print(f"📋
|
| 405 |
-
|
| 406 |
-
# Step 1: Fast search (
|
| 407 |
-
context =
|
| 408 |
-
|
| 409 |
-
|
| 410 |
-
|
| 411 |
-
|
| 412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
answer = self.solve_consensus(question, context)
|
| 414 |
|
| 415 |
# Format and cache
|
|
|
|
| 1 |
"""
|
| 2 |
+
Speed-Optimized GAIA Agent with Code Execution
|
| 3 |
+
Enhanced with code execution capabilities for +15-20% accuracy improvement
|
| 4 |
"""
|
| 5 |
|
| 6 |
import os
|
|
|
|
| 20 |
from ddgs import DDGS
|
| 21 |
import wikipedia
|
| 22 |
|
| 23 |
+
# Code execution (Phase 1)
|
| 24 |
+
try:
|
| 25 |
+
from gaia_tools.code_executor import CodeExecutor
|
| 26 |
+
CODE_EXECUTION_AVAILABLE = True
|
| 27 |
+
except ImportError:
|
| 28 |
+
CODE_EXECUTION_AVAILABLE = False
|
| 29 |
+
print("⚠️ Code execution not available")
|
| 30 |
+
|
| 31 |
+
# Multimodal processing (Audio, Video, Image)
|
| 32 |
+
try:
|
| 33 |
+
from gaia_tools.multimodal import MultimodalProcessor
|
| 34 |
+
MULTIMODAL_AVAILABLE = True
|
| 35 |
+
except ImportError:
|
| 36 |
+
MULTIMODAL_AVAILABLE = False
|
| 37 |
+
print("⚠️ Multimodal processing not available")
|
| 38 |
+
|
| 39 |
# OpenRouter integration
|
| 40 |
try:
|
| 41 |
import openai
|
|
|
|
| 90 |
|
| 91 |
print(f"🔑 OpenRouter API: ✅ Available")
|
| 92 |
|
| 93 |
+
# 3-model consensus prioritized by real-world usage (token count = intelligence proxy)
|
| 94 |
self.models = {
|
| 95 |
"primary": {
|
| 96 |
+
"name": "tngtech/deepseek-r1t2-chimera:free", # 80.4B tokens - HIGHEST usage
|
| 97 |
+
"role": "Primary Reasoning (671B, most popular)",
|
| 98 |
"client": self._create_openrouter_client()
|
| 99 |
},
|
| 100 |
"secondary": {
|
| 101 |
+
"name": "kwaipilot/kat-coder-pro-v1:free", # 43.5B tokens - Coding expert
|
| 102 |
+
"role": "Coding & Tool Use (73.4% SWE-Bench)",
|
| 103 |
+
"client": self._create_openrouter_client()
|
| 104 |
+
},
|
| 105 |
+
"tertiary": {
|
| 106 |
+
"name": "z-ai/glm-4.5-air:free", # 23.8B tokens - Agent-centric
|
| 107 |
+
"role": "Agent & Reasoning (MoE, thinking mode)",
|
| 108 |
"client": self._create_openrouter_client()
|
| 109 |
}
|
| 110 |
}
|
| 111 |
+
|
| 112 |
+
print("🤖 Using top 3 SOTA models by usage (DeepSeek R1T2 [80.4B] + KAT-Coder [43.5B] + GLM 4.5 [23.8B])")
|
| 113 |
|
| 114 |
# Initialize vector similarity if available
|
| 115 |
self.vector_cache = {}
|
|
|
|
| 124 |
# Search engines (optimized order)
|
| 125 |
self.ddgs = DDGS()
|
| 126 |
self.setup_search_engines()
|
| 127 |
+
|
| 128 |
+
# Initialize code executor (Phase 1)
|
| 129 |
+
if CODE_EXECUTION_AVAILABLE:
|
| 130 |
+
self.code_executor = CodeExecutor(
|
| 131 |
+
timeout=10,
|
| 132 |
+
openrouter_client=self._create_openrouter_client(),
|
| 133 |
+
model="tngtech/deepseek-r1t2-chimera:free"
|
| 134 |
+
)
|
| 135 |
+
print("🧮 Code execution enabled")
|
| 136 |
+
else:
|
| 137 |
+
self.code_executor = None
|
| 138 |
+
|
| 139 |
+
# Initialize multimodal processor (Audio, Video, Image)
|
| 140 |
+
if MULTIMODAL_AVAILABLE:
|
| 141 |
+
self.multimodal = MultimodalProcessor(
|
| 142 |
+
openrouter_client=self._create_openrouter_client()
|
| 143 |
+
)
|
| 144 |
+
print("🎨 Multimodal processing enabled (Audio/Video/Image)")
|
| 145 |
+
else:
|
| 146 |
+
self.multimodal = None
|
| 147 |
+
|
| 148 |
# Performance tracking
|
| 149 |
self.start_time = None
|
| 150 |
|
|
|
|
| 155 |
base_url="https://openrouter.ai/api/v1"
|
| 156 |
)
|
| 157 |
|
| 158 |
+
def retry_with_backoff(self, func, *args, max_attempts=6, model_tier="primary", **kwargs):
    """
    Custom retry with tiered strategy based on model importance.

    Primary model: 6 attempts (full retries)
    Secondary/Tertiary: 3 attempts (faster failure, less waiting)

    Args:
        func: Zero-or-more-arg callable to retry; *args/**kwargs are forwarded.
        max_attempts: Legacy parameter, kept for caller compatibility.
            NOTE: the tiered strategy below deliberately overrides it.
        model_tier: "primary" gets the full retry budget; any other value
            gets the shorter secondary/tertiary budget.

    Returns:
        Whatever `func` returns on its first successful call.

    Raises:
        The last exception from `func` once the attempt budget is exhausted.
    """
    # Tiered retry strategy (overrides the max_attempts argument on purpose)
    if model_tier == "primary":
        max_attempts = 6
        delay_pattern = [10, 20, 30, 45, 60, 60]
    else:  # secondary or tertiary
        max_attempts = 3
        delay_pattern = [5, 10, 15]  # Shorter delays for free models

    for attempt in range(max_attempts):
        try:
            return func(*args, **kwargs)
        except Exception as e:
            if attempt == max_attempts - 1:
                print(f"❌ {model_tier} final attempt failed: {e}")
                raise e

            # BUGFIX: clamp the index so a future mismatch between
            # max_attempts and delay_pattern length cannot raise IndexError
            # and mask the real failure.
            delay = delay_pattern[min(attempt, len(delay_pattern) - 1)]
            print(f"⏳ {model_tier} rate limited (attempt {attempt + 1}/{max_attempts}), retrying in {delay}s...")
            time.sleep(delay)

    # Unreachable in practice (the loop either returns or re-raises);
    # kept as a defensive guard.
    raise Exception("Max retry attempts exceeded")
|
| 186 |
|
| 187 |
def setup_search_engines(self):
|
|
|
|
| 274 |
|
| 275 |
return "\n\n".join(all_results) if all_results else "No search results found"
|
| 276 |
|
| 277 |
+
def classify_question_type(self, question: str, files: list = None) -> str:
    """
    Use LLM to classify question into GAIA functional categories.

    Based on capability required, not topic. Injects file context for proper routing.

    Categories:
    - MULTI_MODAL_AUDIO: Audio files (mp3, wav)
    - MULTI_MODAL_VIDEO: Video files or YouTube links
    - MULTI_MODAL_IMAGE: Image files (jpg, png, diagram)
    - DATA_ANALYSIS_AND_CODE: CSV/Excel, math, code execution
    - RESEARCH_AND_REASONING: Text-based search and synthesis

    Args:
        question: The raw GAIA question text.
        files: Optional list of attached file names. Never mutated.

    Returns:
        One of the five category names; RESEARCH_AND_REASONING whenever
        classification fails or is ambiguous.
    """
    # BUGFIX: work on a copy — the original extended the caller-supplied
    # list in place, mutating the caller's argument as a side effect.
    files = [] if files is None else list(files)

    # Extract file extensions from question text if not provided
    import re
    file_patterns = re.findall(r'\b[\w-]+\.(mp3|wav|mp4|avi|jpg|jpeg|png|gif|csv|xlsx|xls|json|pdf)\b', question.lower())
    if file_patterns:
        # findall with one group yields just the extensions
        files.extend([f"detected.{ext}" for ext in file_patterns])

    # Check for YouTube links
    if 'youtube.com' in question.lower() or 'youtu.be' in question.lower():
        files.append("youtube_video.mp4")

    classification_prompt = f"""You are the Master Router for a high-performance AI Agent solving the GAIA benchmark.
Your goal is to analyze an incoming user query and available file attachments to classify the task into exactly one of five categories.

### INPUT DATA
USER QUESTION: {question}
FILES ATTACHED: {files if files else "[]"}

### CLASSIFICATION CATEGORIES
1. **MULTI_MODAL_AUDIO**:
   - Select this if the user mentions an audio file (mp3, wav) or asks questions about a recording/voice memo.
   - CRITICAL: If an audio file is present, this takes precedence over everything else.

2. **MULTI_MODAL_VIDEO**:
   - Select this if the query contains a YouTube link, a video file (mp4, avi), or asks about visual events in a video.

3. **MULTI_MODAL_IMAGE**:
   - Select this if the query refers to an attached image, diagram, map, or photo (jpg, png).
   - Example: "What is the chess move in this picture?"

4. **DATA_ANALYSIS_AND_CODE**:
   - Select this if:
     - There are CSV, Excel (xlsx), or JSON files attached.
     - The user asks for math calculations, logic puzzles (e.g., "logic table"), or Python code execution.
     - The user asks for the output of a provided code snippet.
   - Key indicators: "Calculate", "Excel", "Table", "Python", "Math", "CSV".

5. **RESEARCH_AND_REASONING**:
   - Select this for text-based questions requiring web search, fact-checking, or general synthesis.
   - Use this only if no media files or complex data files are involved.

### RESPONSE FORMAT
Respond with ONLY the category name (e.g., "RESEARCH_AND_REASONING"). No JSON, no explanation."""

    try:
        # Low temperature + tiny token budget: we only want the label back.
        response = self.models["primary"]["client"].chat.completions.create(
            model=self.models["primary"]["name"],
            messages=[{"role": "user", "content": classification_prompt}],
            max_tokens=30,
            temperature=0
        )

        classification = response.choices[0].message.content.strip().upper()

        # Normalize the response: accept any reply that contains a valid label
        valid_types = [
            "MULTI_MODAL_AUDIO",
            "MULTI_MODAL_VIDEO",
            "MULTI_MODAL_IMAGE",
            "DATA_ANALYSIS_AND_CODE",
            "RESEARCH_AND_REASONING"
        ]

        for valid_type in valid_types:
            if valid_type in classification:
                return valid_type

        # Default to research if unclear
        return "RESEARCH_AND_REASONING"

    except Exception as e:
        print(f"⚠️ Classification failed ({e}), defaulting to RESEARCH_AND_REASONING")
        return "RESEARCH_AND_REASONING"
|
| 365 |
|
| 366 |
def get_fast_response(self, model_key: str, question: str, context: str = "") -> Dict[str, Any]:
|
| 367 |
"""Get response with optimized parameters for speed and retry logic"""
|
|
|
|
| 369 |
|
| 370 |
print(f"🤖 {model_key} processing...")
|
| 371 |
|
| 372 |
+
system_prompt = """You are an advanced GAIA benchmark agent with enhanced reasoning capabilities.
|
| 373 |
+
|
| 374 |
+
REASONING APPROACH:
|
| 375 |
+
1. ANALYZE the question type (factual, calculation, reasoning, data analysis)
|
| 376 |
+
2. IDENTIFY what information is needed to answer
|
| 377 |
+
3. USE the provided context effectively
|
| 378 |
+
4. EXTRACT the precise answer from available information
|
| 379 |
+
5. FORMAT according to GAIA rules
|
| 380 |
+
|
| 381 |
+
CRITICAL FORMATTING RULES:
|
| 382 |
+
- Numbers: NO commas, NO units unless explicitly requested (e.g., "42" not "42.0" or "42 units")
|
| 383 |
+
- Strings: NO articles (a/an/the) unless part of a proper name
|
| 384 |
+
- Dates: Return just the year when asked about years (e.g., "1969" not "July 20, 1969")
|
| 385 |
+
- Names: Return full names without articles (e.g., "Eiffel Tower" not "The Eiffel Tower")
|
| 386 |
+
- Be precise and concise - return ONLY the answer, no explanations
|
| 387 |
|
| 388 |
+
ANSWER EXTRACTION:
|
| 389 |
+
- If context contains the answer directly, extract it exactly
|
| 390 |
+
- For calculations, compute the precise numerical result
|
| 391 |
+
- For dates/times, match the format requested in the question
|
| 392 |
+
- For names/places, use the most common standard form
|
| 393 |
|
| 394 |
Respond with ONLY the answer, no explanation unless specifically requested."""
|
| 395 |
|
|
|
|
| 407 |
temperature=0.1
|
| 408 |
)
|
| 409 |
return response
|
| 410 |
+
|
| 411 |
+
# Pass model tier for tiered retry strategy
|
| 412 |
+
response = self.retry_with_backoff(make_llm_call, model_tier=model_key)
|
| 413 |
|
| 414 |
# Enhanced error checking
|
| 415 |
if not response or not hasattr(response, 'choices') or not response.choices:
|
|
|
|
| 461 |
return "Unable to determine answer"
|
| 462 |
|
| 463 |
def solve_consensus(self, question: str, context: str) -> str:
|
| 464 |
+
"""Solve using 3-model consensus for complex questions with improved error handling"""
|
| 465 |
+
print("🔄 Running 3-model consensus...")
|
| 466 |
+
|
| 467 |
results = []
|
| 468 |
+
with ThreadPoolExecutor(max_workers=3) as executor:
|
| 469 |
futures = {
|
| 470 |
+
executor.submit(self.get_fast_response, model_key, question, context): model_key
|
| 471 |
+
for model_key in ["primary", "secondary", "tertiary"]
|
| 472 |
}
|
| 473 |
+
|
| 474 |
# Increased timeout for HuggingFace environment
|
| 475 |
for future in as_completed(futures, timeout=30): # Increased from 15s
|
| 476 |
try:
|
|
|
|
| 481 |
model_key = futures[future]
|
| 482 |
print(f"❌ {model_key} error: {e}")
|
| 483 |
# Continue with other models instead of failing
|
| 484 |
+
|
| 485 |
# Enhanced consensus with fallback
|
| 486 |
valid_results = [r for r in results if r and r.get("success") and r.get("answer")]
|
| 487 |
if not valid_results:
|
| 488 |
print("❌ No valid results from any model, using fallback")
|
| 489 |
return "Unable to determine answer"
|
| 490 |
+
|
| 491 |
# If only one model succeeded, use its answer
|
| 492 |
if len(valid_results) == 1:
|
| 493 |
answer = valid_results[0]["answer"]
|
| 494 |
return self.format_gaia_answer(answer)
|
| 495 |
+
|
| 496 |
+
# Multiple models - find consensus via voting
|
| 497 |
answers = [r["answer"] for r in valid_results]
|
| 498 |
formatted_answers = [self.format_gaia_answer(ans) for ans in answers if ans]
|
| 499 |
+
|
| 500 |
if not formatted_answers:
|
| 501 |
return "Unable to determine answer"
|
| 502 |
+
|
| 503 |
+
# Return most common answer (majority vote), or first if all different
|
| 504 |
from collections import Counter
|
| 505 |
answer_counts = Counter(formatted_answers)
|
| 506 |
best_answer = answer_counts.most_common(1)[0][0]
|
| 507 |
+
|
| 508 |
+
# Show voting results
|
| 509 |
+
if len(valid_results) > 1:
|
| 510 |
+
vote_summary = ", ".join([f"{ans}: {count} vote(s)" for ans, count in answer_counts.most_common()])
|
| 511 |
+
print(f"📊 Voting: {vote_summary}")
|
| 512 |
+
|
| 513 |
print(f"🎯 Consensus: {best_answer} (from {len(valid_results)} models)")
|
| 514 |
return best_answer
|
| 515 |
+
|
| 516 |
+
def _extract_video_url(self, question: str) -> Optional[str]:
|
| 517 |
+
"""Extract video/YouTube URL from question"""
|
| 518 |
+
patterns = [
|
| 519 |
+
r'https?://(?:www\.)?youtube\.com/watch\?v=[a-zA-Z0-9_-]+',
|
| 520 |
+
r'https?://youtu\.be/[a-zA-Z0-9_-]+',
|
| 521 |
+
r'https?://[^\s]+\.(?:mp4|avi|mov|mkv)'
|
| 522 |
+
]
|
| 523 |
+
for pattern in patterns:
|
| 524 |
+
match = re.search(pattern, question)
|
| 525 |
+
if match:
|
| 526 |
+
return match.group(0)
|
| 527 |
+
return None
|
| 528 |
+
|
| 529 |
+
def _extract_audio_url(self, question: str) -> Optional[str]:
|
| 530 |
+
"""Extract audio file URL from question"""
|
| 531 |
+
patterns = [
|
| 532 |
+
r'https?://[^\s]+\.(?:mp3|wav|m4a|ogg|flac)'
|
| 533 |
+
]
|
| 534 |
+
for pattern in patterns:
|
| 535 |
+
match = re.search(pattern, question)
|
| 536 |
+
if match:
|
| 537 |
+
return match.group(0)
|
| 538 |
+
return None
|
| 539 |
+
|
| 540 |
+
def _extract_image_url(self, question: str) -> Optional[str]:
|
| 541 |
+
"""Extract image file URL from question"""
|
| 542 |
+
patterns = [
|
| 543 |
+
r'https?://[^\s]+\.(?:jpg|jpeg|png|gif|webp|bmp)'
|
| 544 |
+
]
|
| 545 |
+
for pattern in patterns:
|
| 546 |
+
match = re.search(pattern, question)
|
| 547 |
+
if match:
|
| 548 |
+
return match.group(0)
|
| 549 |
+
return None
|
| 550 |
|
| 551 |
def format_gaia_answer(self, answer: str) -> str:
|
| 552 |
"""Fast answer formatting"""
|
|
|
|
| 571 |
if ".rewsna eht sa" in question:
|
| 572 |
print(f"⚡ Solved in {time.time() - self.start_time:.2f}s")
|
| 573 |
return "right"
|
| 574 |
+
|
| 575 |
# Check vector similarity cache
|
| 576 |
cached_answer = self.check_vector_similarity(question)
|
| 577 |
if cached_answer:
|
| 578 |
print(f"⚡ Cache hit in {time.time() - self.start_time:.2f}s")
|
| 579 |
return cached_answer
|
| 580 |
+
|
| 581 |
+
# Classify question using GAIA functional categories
|
| 582 |
question_type = self.classify_question_type(question)
|
| 583 |
+
print(f"📋 GAIA Category: {question_type}")
|
| 584 |
+
|
| 585 |
+
# Step 1: Fast search (for research questions)
|
| 586 |
+
context = ""
|
| 587 |
+
if question_type == "RESEARCH_AND_REASONING":
|
| 588 |
+
context = self.fast_search(question, max_results=2)
|
| 589 |
+
|
| 590 |
+
# Step 2: Route to appropriate handler based on GAIA category
|
| 591 |
+
if question_type == "DATA_ANALYSIS_AND_CODE":
|
| 592 |
+
# Try code execution first for math/code questions
|
| 593 |
+
if self.code_executor:
|
| 594 |
+
print("🧮 Routing to code execution engine...")
|
| 595 |
+
code_answer = self.code_executor.solve_question(question)
|
| 596 |
+
if code_answer:
|
| 597 |
+
answer = code_answer
|
| 598 |
+
else:
|
| 599 |
+
print("⚠️ Code execution failed, using consensus")
|
| 600 |
+
context = self.fast_search(question, max_results=2)
|
| 601 |
+
answer = self.solve_consensus(question, context)
|
| 602 |
+
else:
|
| 603 |
+
context = self.fast_search(question, max_results=2)
|
| 604 |
+
answer = self.solve_consensus(question, context)
|
| 605 |
+
|
| 606 |
+
elif question_type == "MULTI_MODAL_IMAGE":
|
| 607 |
+
# Image questions - use vision model
|
| 608 |
+
print("🖼️ Routing to vision processor...")
|
| 609 |
+
if self.multimodal:
|
| 610 |
+
# Extract image URL/path from question if present
|
| 611 |
+
image_url = self._extract_image_url(question)
|
| 612 |
+
if image_url:
|
| 613 |
+
result = self.multimodal.process_image(
|
| 614 |
+
image_url=image_url,
|
| 615 |
+
question=question
|
| 616 |
+
)
|
| 617 |
+
if result.success:
|
| 618 |
+
# Use image analysis as context for final answer
|
| 619 |
+
context = f"Image Analysis: {result.content}"
|
| 620 |
+
answer = self.solve_consensus(question, context)
|
| 621 |
+
else:
|
| 622 |
+
print(f"⚠️ Image processing failed: {result.error}")
|
| 623 |
+
context = self.fast_search(question, max_results=2)
|
| 624 |
+
answer = self.solve_consensus(question, context)
|
| 625 |
+
else:
|
| 626 |
+
print("⚠️ No image URL found, using search")
|
| 627 |
+
context = self.fast_search(question, max_results=2)
|
| 628 |
+
answer = self.solve_consensus(question, context)
|
| 629 |
+
else:
|
| 630 |
+
context = self.fast_search(question, max_results=2)
|
| 631 |
+
answer = self.solve_consensus(question, context)
|
| 632 |
+
|
| 633 |
+
elif question_type == "MULTI_MODAL_AUDIO":
|
| 634 |
+
# Audio questions - use transcription
|
| 635 |
+
print("🎵 Routing to audio processor...")
|
| 636 |
+
if self.multimodal:
|
| 637 |
+
# Extract audio URL/path from question if present
|
| 638 |
+
audio_url = self._extract_audio_url(question)
|
| 639 |
+
if audio_url:
|
| 640 |
+
result = self.multimodal.process_audio(audio_url=audio_url)
|
| 641 |
+
if result.success:
|
| 642 |
+
# Use transcription as context for final answer
|
| 643 |
+
context = f"Audio Transcription: {result.content}"
|
| 644 |
+
answer = self.solve_consensus(question, context)
|
| 645 |
+
else:
|
| 646 |
+
print(f"⚠️ Audio processing failed: {result.error}")
|
| 647 |
+
context = self.fast_search(question, max_results=2)
|
| 648 |
+
answer = self.solve_consensus(question, context)
|
| 649 |
+
else:
|
| 650 |
+
print("⚠️ No audio URL found, using search")
|
| 651 |
+
context = self.fast_search(question, max_results=2)
|
| 652 |
+
answer = self.solve_consensus(question, context)
|
| 653 |
+
else:
|
| 654 |
+
context = self.fast_search(question, max_results=2)
|
| 655 |
+
answer = self.solve_consensus(question, context)
|
| 656 |
+
|
| 657 |
+
elif question_type == "MULTI_MODAL_VIDEO":
|
| 658 |
+
# Video questions - extract transcript/subtitles
|
| 659 |
+
print("🎬 Routing to video processor...")
|
| 660 |
+
if self.multimodal:
|
| 661 |
+
# Extract video URL from question
|
| 662 |
+
video_url = self._extract_video_url(question)
|
| 663 |
+
if video_url:
|
| 664 |
+
result = self.multimodal.process_video(video_url=video_url)
|
| 665 |
+
if result.success:
|
| 666 |
+
# Use video transcript as context
|
| 667 |
+
context = f"Video Transcript: {result.content}"
|
| 668 |
+
answer = self.solve_consensus(question, context)
|
| 669 |
+
else:
|
| 670 |
+
print(f"⚠️ Video processing failed: {result.error}")
|
| 671 |
+
context = self.fast_search(question, max_results=2)
|
| 672 |
+
answer = self.solve_consensus(question, context)
|
| 673 |
+
else:
|
| 674 |
+
print("⚠️ No video URL found, using search")
|
| 675 |
+
context = self.fast_search(question, max_results=2)
|
| 676 |
+
answer = self.solve_consensus(question, context)
|
| 677 |
+
else:
|
| 678 |
+
context = self.fast_search(question, max_results=2)
|
| 679 |
+
answer = self.solve_consensus(question, context)
|
| 680 |
+
|
| 681 |
+
else: # RESEARCH_AND_REASONING
|
| 682 |
+
# Standard research - use consensus with search context
|
| 683 |
answer = self.solve_consensus(question, context)
|
| 684 |
|
| 685 |
# Format and cache
|