HuggingFace_Agent_Cert / gaia_tools /error_analysis.py
AgileAndy's picture
updated improved version
f504b2e verified
"""
GAIA Error Analysis Framework
Categorizes questions, failure modes, and generates actionable improvement recommendations.
Implements TDD test suite specifications from tests/test_error_analysis.py
"""
import csv
import json
import re
from dataclasses import dataclass, asdict
from enum import Enum
from typing import List, Dict, Optional, Any
from collections import defaultdict, Counter
class QuestionType(Enum):
    """Categories of GAIA questions"""
    MATH = "math"              # arithmetic / numeric computation questions
    FILE = "file"              # requires reading an attached document (PDF/CSV/Excel/...)
    WEB = "web"                # requires a web lookup (people, organizations, facts)
    IMAGE = "image"            # requires understanding an image/screenshot
    AUDIO = "audio"            # requires transcribing or understanding audio
    REASONING = "reasoning"    # logical / comparative reasoning (if-then, tallest, ...)
    MULTIMODAL = "multimodal"  # combines modalities (e.g. image + spreadsheet)
    UNKNOWN = "unknown"        # fallback when no classification pattern matches
class FailureMode(Enum):
    """Categories of answer failures"""
    WRONG_ANSWER = "wrong_answer"          # substantively incorrect answer
    FORMATTING_ERROR = "formatting_error"  # right content, wrong format (commas, units, articles)
    TIMEOUT = "timeout"                    # a TimeoutError was raised while answering
    TOOL_FAILURE = "tool_failure"          # any other exception raised while answering
    EMPTY_RESPONSE = "empty_response"      # no answer, blank answer, or "unable to determine"
@dataclass
class TestResult:
    """Represents a single test result (one question, the agent's answer, and its outcome)."""
    question_id: str                            # unique identifier for the question
    question: str                               # full question text
    question_type: QuestionType                 # classification produced by classify_question_type
    expected: str                               # ground-truth answer
    actual: str                                 # answer produced by the agent
    success: bool                               # True when actual exactly matched expected
    failure_mode: Optional[FailureMode] = None  # set only when success is False
    time_elapsed: float = 0.0                   # seconds taken to produce the answer
    tools_used: Optional[List[str]] = None      # normalized to [] in __post_init__
    error: Optional[Exception] = None           # exception raised while answering, if any

    def __post_init__(self) -> None:
        # Replace the None sentinel with a fresh list per instance
        # (avoids the shared-mutable-default pitfall).
        if self.tools_used is None:
            self.tools_used = []
class GAIATestAnalyzer:
    """
    Analyzes GAIA agent test results to identify failure patterns and recommend improvements.

    This class implements error categorization, performance tracking, and reporting
    to guide agent optimization efforts.

    Typical usage:
        analyzer = GAIATestAnalyzer()
        analyzer.analyze_response(qid, question, expected, actual)
        summary = analyzer.generate_summary()
        recommendations = analyzer.get_recommendations()
    """

    def __init__(self):
        # All logged results, in insertion order.
        self.results: List[TestResult] = []
        # Raw regex patterns for question classification. Kept as public
        # attributes for backward compatibility; compiled once below so
        # classify_question_type does not recompile them on every call.
        self.math_patterns = [
            r'\d+\s*[\+\-\*\/]\s*\d+',  # Arithmetic operations with numbers
            r'calculate|compute|sum|multiply|divide|subtract|add',
            r'what is \d+',
            r'how many|how much'
        ]
        self.file_patterns = [
            r'pdf|csv|excel|spreadsheet|document|table|file',
            r'attached|according to the',
        ]
        self.image_patterns = [
            r'image|picture|photo|screenshot|attached.*color|in the (attached )?image'
        ]
        self.audio_patterns = [
            r'audio|recording|sound|said in|spoken|voice'
        ]
        self.web_patterns = [
            r'who is|what is the (current|latest)|CEO|president|founded|website',
            r'according to.*wikipedia|look up'
        ]
        self.reasoning_patterns = [
            r'if .+ then|taller than|shorter than|before|after',
            r'who is the (tallest|shortest|oldest|youngest)',
        ]
        self.multimodal_patterns = [
            r'(image|picture|photo).*(csv|file|data|spreadsheet)',
            r'(csv|file|data|spreadsheet).*(image|picture|photo)',
            r'using the .+ and the'
        ]
        # Priority-ordered classification table: the first type whose
        # patterns match wins (multimodal outranks single modalities,
        # which outrank math/reasoning/web).
        self._classification_order = [
            (QuestionType.MULTIMODAL, self.multimodal_patterns),
            (QuestionType.IMAGE, self.image_patterns),
            (QuestionType.AUDIO, self.audio_patterns),
            (QuestionType.FILE, self.file_patterns),
            (QuestionType.MATH, self.math_patterns),
            (QuestionType.REASONING, self.reasoning_patterns),
            (QuestionType.WEB, self.web_patterns),
        ]
        # Compile each pattern once, preserving the original IGNORECASE search.
        self._compiled_patterns = {
            qtype: [re.compile(p, re.IGNORECASE) for p in patterns]
            for qtype, patterns in self._classification_order
        }

    def classify_question_type(self, question: str) -> QuestionType:
        """
        Classify a question into a QuestionType based on its content.

        Args:
            question: The question text to classify

        Returns:
            QuestionType enum value (UNKNOWN when no pattern matches)
        """
        question_lower = question.lower()
        # Walk the priority-ordered table; first matching type wins.
        for qtype, _ in self._classification_order:
            if any(pattern.search(question_lower)
                   for pattern in self._compiled_patterns[qtype]):
                return qtype
        return QuestionType.UNKNOWN

    def classify_failure_mode(
        self,
        expected: str,
        actual: Optional[str],
        error: Optional[Exception] = None
    ) -> FailureMode:
        """
        Classify why an answer failed.

        Args:
            expected: The correct answer
            actual: The agent's answer (None if error occurred)
            error: Exception if one occurred

        Returns:
            FailureMode enum value
        """
        # Check for exceptions first
        if error is not None:
            if isinstance(error, TimeoutError):
                return FailureMode.TIMEOUT
            else:
                return FailureMode.TOOL_FAILURE
        # Check for empty/unable responses
        if actual is None or actual.strip() == "":
            return FailureMode.EMPTY_RESPONSE
        if "unable to determine" in actual.lower():
            return FailureMode.EMPTY_RESPONSE
        # Check for formatting errors
        expected_clean = expected.strip().lower()
        actual_clean = actual.strip().lower()
        # BUGFIX: answers that differ only in case or surrounding whitespace
        # (e.g. "Paris" vs "paris ") previously fell through every formatting
        # check and were misclassified as WRONG_ANSWER.
        if expected_clean == actual_clean:
            return FailureMode.FORMATTING_ERROR
        # Remove commas and check if answers match
        expected_no_comma = expected_clean.replace(',', '')
        actual_no_comma = actual_clean.replace(',', '')
        if expected_no_comma == actual_no_comma and expected_clean != actual_clean:
            return FailureMode.FORMATTING_ERROR
        # Check for unwanted units
        if actual_clean.startswith(expected_clean):
            remainder = actual_clean[len(expected_clean):].strip()
            if remainder:  # Has extra content (likely units)
                return FailureMode.FORMATTING_ERROR
        # Check for articles (the, a, an)
        articles = ['the ', 'a ', 'an ']
        for article in articles:
            if actual_clean.startswith(article):
                without_article = actual_clean[len(article):]
                if without_article == expected_clean:
                    return FailureMode.FORMATTING_ERROR
        # If none of the above, it's a wrong answer
        return FailureMode.WRONG_ANSWER

    def log_result(self, result: TestResult):
        """
        Add a test result to the analyzer.

        Args:
            result: TestResult object to log
        """
        self.results.append(result)

    def analyze_response(
        self,
        question_id: str,
        question: str,
        expected: str,
        actual: str,
        time_elapsed: float = 0.0,
        tools_used: Optional[List[str]] = None,
        error: Optional[Exception] = None
    ) -> TestResult:
        """
        Analyze a single agent response and create a TestResult.

        This is a convenience method that combines classification and logging.

        Args:
            question_id: Unique identifier for the question
            question: The question text
            expected: The correct answer
            actual: The agent's answer
            time_elapsed: Time taken to answer
            tools_used: List of tools used by the agent
            error: Exception if one occurred

        Returns:
            TestResult object with all classifications
        """
        question_type = self.classify_question_type(question)
        # Success requires an exact string match; formatting near-misses are
        # surfaced separately via classify_failure_mode.
        success = (actual == expected) if actual is not None else False
        failure_mode = None
        if not success:
            failure_mode = self.classify_failure_mode(expected, actual, error)
        result = TestResult(
            question_id=question_id,
            question=question,
            question_type=question_type,
            expected=expected,
            actual=actual,
            success=success,
            failure_mode=failure_mode,
            time_elapsed=time_elapsed,
            tools_used=tools_used or [],
            error=error
        )
        self.log_result(result)
        return result

    def generate_summary(self) -> Dict[str, Any]:
        """
        Generate summary statistics for all logged results.

        Returns:
            Dictionary with keys: total_questions, correct_count, accuracy, avg_time
        """
        if not self.results:
            return {
                "total_questions": 0,
                "correct_count": 0,
                "accuracy": 0.0,
                "avg_time": 0.0
            }
        total = len(self.results)
        correct = sum(1 for r in self.results if r.success)
        total_time = sum(r.time_elapsed for r in self.results)
        return {
            "total_questions": total,
            "correct_count": correct,
            "accuracy": correct / total if total > 0 else 0.0,
            "avg_time": total_time / total if total > 0 else 0.0
        }

    def get_accuracy_by_type(self) -> Dict[QuestionType, float]:
        """
        Calculate accuracy broken down by question type.

        Returns:
            Dictionary mapping QuestionType to accuracy (0.0-1.0); only types
            that appear in the logged results are included.
        """
        type_stats = defaultdict(lambda: {"correct": 0, "total": 0})
        for result in self.results:
            stats = type_stats[result.question_type]
            stats["total"] += 1
            if result.success:
                stats["correct"] += 1
        accuracy_by_type = {}
        for qtype, stats in type_stats.items():
            accuracy_by_type[qtype] = (
                stats["correct"] / stats["total"] if stats["total"] > 0 else 0.0
            )
        return accuracy_by_type

    def get_failures_by_mode(self) -> Dict[FailureMode, int]:
        """
        Count failures by failure mode.

        Returns:
            Dictionary mapping FailureMode to count (modes with no failures
            are omitted).
        """
        failure_counts = Counter()
        for result in self.results:
            if not result.success and result.failure_mode:
                failure_counts[result.failure_mode] += 1
        return dict(failure_counts)

    def export_to_csv(self, filepath: str):
        """
        Export all results to a CSV file.

        Args:
            filepath: Path to output CSV file
        """
        with open(filepath, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            # Write header
            writer.writerow([
                'question_id', 'question', 'question_type', 'expected', 'actual',
                'success', 'failure_mode', 'time_elapsed', 'tools_used'
            ])
            # Write results. NOTE: enum values are upper-cased here but
            # lower-cased in export_to_json — kept as-is for compatibility
            # with existing consumers of both formats.
            for result in self.results:
                writer.writerow([
                    result.question_id,
                    result.question,
                    result.question_type.value.upper(),
                    result.expected,
                    result.actual,
                    result.success,
                    result.failure_mode.value.upper() if result.failure_mode else '',
                    result.time_elapsed,
                    ','.join(result.tools_used) if result.tools_used else ''
                ])

    def export_to_json(self, filepath: str):
        """
        Export all results and summary to a JSON file.

        Args:
            filepath: Path to output JSON file
        """
        data = {
            "summary": self.generate_summary(),
            "accuracy_by_type": {
                qtype.value: acc
                for qtype, acc in self.get_accuracy_by_type().items()
            },
            "failures_by_mode": {
                mode.value: count
                for mode, count in self.get_failures_by_mode().items()
            },
            "results": [
                {
                    "question_id": r.question_id,
                    "question": r.question,
                    "question_type": r.question_type.value,
                    "expected": r.expected,
                    "actual": r.actual,
                    "success": r.success,
                    "failure_mode": r.failure_mode.value if r.failure_mode else None,
                    "time_elapsed": r.time_elapsed,
                    "tools_used": r.tools_used
                }
                for r in self.results
            ]
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2)

    def get_recommendations(self) -> List[str]:
        """
        Generate actionable recommendations based on failure analysis.

        Returns:
            List of recommendation strings, ordered by the rule tables below.
        """
        recommendations: List[str] = []
        accuracy_by_type = self.get_accuracy_by_type()
        failures_by_mode = self.get_failures_by_mode()

        # (question type, accuracy threshold, recommendation) — fires only
        # when at least one result of that type exists and its accuracy is
        # below the threshold.
        leading_type_rules = [
            (QuestionType.IMAGE, 0.5,
             "Add vision capabilities (Gemini 2.5 Pro) to handle image questions"),
            (QuestionType.FILE, 0.5,
             "Implement file processing capabilities (PDF/CSV/Excel parsing)"),
            (QuestionType.MATH, 0.7,
             "Add code execution capabilities for reliable math calculations"),
        ]
        # (failure mode, fraction-of-total threshold, recommendation) — fires
        # when that failure mode exceeds the given fraction of all results.
        mode_rules = [
            (FailureMode.FORMATTING_ERROR, 0.1,
             "Improve answer formatting logic to handle commas, units, and articles"),
            (FailureMode.EMPTY_RESPONSE, 0.1,
             "Improve tool reliability and add fallback mechanisms for empty responses"),
            (FailureMode.TIMEOUT, 0.05,
             "Optimize query speed and increase timeout thresholds for complex questions"),
        ]
        trailing_type_rules = [
            (QuestionType.AUDIO, 0.5,
             "Add audio transcription capabilities (Whisper)"),
            (QuestionType.MULTIMODAL, 0.5,
             "Improve multimodal reasoning by integrating multiple tool outputs"),
        ]

        recommendations.extend(
            self._type_recommendations(leading_type_rules, accuracy_by_type))
        total = len(self.results)
        for mode, fraction, message in mode_rules:
            if failures_by_mode.get(mode, 0) > total * fraction:
                recommendations.append(message)
        recommendations.extend(
            self._type_recommendations(trailing_type_rules, accuracy_by_type))
        return recommendations

    def _type_recommendations(self, rules, accuracy_by_type) -> List[str]:
        """Evaluate (type, threshold, message) rules against per-type accuracy."""
        messages = []
        for qtype, threshold, message in rules:
            has_results = any(r.question_type == qtype for r in self.results)
            if has_results and accuracy_by_type.get(qtype, 0.0) < threshold:
                messages.append(message)
        return messages