#!/usr/bin/env python3
"""
Utility functions for the GAIA Benchmark Agent.
This module provides helpers for progress tracking, performance analysis,
and batch question processing.
"""
import os
import json
import datetime
from typing import Any, Optional
from gaiaX.config import logger, CONFIG
def load_progress(progress_file: Optional[str] = None) -> dict:
    """
    Load progress from a JSON file.
    Args:
        progress_file: Path to the progress file. Defaults to the path
            configured under CONFIG["paths"]["progress_file"].
    Returns:
        Dictionary with "processed_questions" and "answers" keys; an empty
        structure is returned if the file is missing or unreadable.
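    Example:
        # Illustrative sketch: resume a run by skipping answered tasks.
        progress = load_progress()
        done = set(progress.get("processed_questions", []))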
"""
if not progress_file:
progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")
try:
if os.path.exists(progress_file):
with open(progress_file, 'r') as f:
progress = json.load(f)
return progress
else:
return {"processed_questions": [], "answers": {}}
except Exception as e:
logger.error(f"Error loading progress from {progress_file}: {e}")
return {"processed_questions": [], "answers": {}}
def save_progress(progress_data: dict, progress_file: Optional[str] = None) -> bool:
    """
    Save progress to a JSON file.
    Args:
        progress_data: Dictionary containing progress data
        progress_file: Path to the progress file. Defaults to the path
            configured under CONFIG["paths"]["progress_file"].
    Returns:
        True if the file was written successfully, False otherwise
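    Example:
        # Illustrative sketch; "task-123" is a placeholder task ID.
        progress = load_progress()
        progress["processed_questions"].append("task-123")
        save_progress(progress)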
"""
if not progress_file:
progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")
try:
with open(progress_file, 'w') as f:
json.dump(progress_data, f, indent=2)
return True
except Exception as e:
logger.error(f"Error saving progress to {progress_file}: {e}")
return False
def analyze_performance(answers: list, expected_answers: Optional[list] = None) -> dict:
    """
    Analyze the performance of the agent based on answers.
    Args:
        answers: List of answer dictionaries, each with "task_id" and
            "answer" and optionally "error", "response_time", and
            "question_type"
        expected_answers: Optional list of expected-answer dictionaries
            (each with "task_id" and "answer") used to compute accuracy
    Returns:
        Dictionary containing performance metrics
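    Example:
        # Illustrative records; real ones come from process_questions_batch().
        metrics = analyze_performance(
            [{"task_id": "t1", "answer": "42", "response_time": 1.5},
             {"task_id": "t2", "answer": "", "error": "timeout"}]
        )
        # metrics["success_rate"] == 0.5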
"""
total_questions = len(answers)
successful_answers = sum(1 for a in answers if "error" not in a)
error_count = total_questions - successful_answers
# Calculate average response time if available
response_times = [a.get("response_time", 0) for a in answers if "response_time" in a]
avg_response_time = sum(response_times) / len(response_times) if response_times else 0
# Count question types
question_types = {}
for answer in answers:
q_type = answer.get("question_type", "unknown")
question_types[q_type] = question_types.get(q_type, 0) + 1
# Calculate accuracy if expected answers are provided
accuracy = None
correct_answers = 0
if expected_answers:
answer_dict = {a.get("task_id"): a.get("answer") for a in answers}
expected_dict = {e.get("task_id"): e.get("answer") for e in expected_answers}
common_ids = set(answer_dict.keys()) & set(expected_dict.keys())
if common_ids:
for task_id in common_ids:
if answer_dict[task_id] == expected_dict[task_id]:
correct_answers += 1
accuracy = correct_answers / len(common_ids)
# Compile metrics
metrics = {
"total_questions": total_questions,
"successful_answers": successful_answers,
"error_count": error_count,
"success_rate": successful_answers / total_questions if total_questions > 0 else 0,
"average_response_time": avg_response_time,
"question_types": question_types
}
    if accuracy is not None:
        metrics["accuracy"] = accuracy
        metrics["correct_answers"] = correct_answers
        # Accuracy covers only tasks that also appear in expected_answers
        metrics["evaluated_questions"] = len(common_ids)
return metrics
def format_performance_report(metrics: dict) -> str:
"""
Format performance metrics into a readable report.
Args:
metrics: Dictionary containing performance metrics
Returns:
Formatted performance report as a string
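    Example:
        # Illustrative; "answers" is a list of answer dicts from a run.
        print(format_performance_report(analyze_performance(answers)))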
"""
report = [
"=== GAIA Benchmark Agent Performance Report ===",
f"Generated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
"",
f"Total Questions Processed: {metrics['total_questions']}",
f"Successful Answers: {metrics['successful_answers']} ({metrics['success_rate']:.2%})",
f"Errors: {metrics['error_count']}",
f"Average Response Time: {metrics['average_response_time']:.2f} seconds",
"",
"Question Type Distribution:"
]
# Add question type distribution
for q_type, count in metrics.get("question_types", {}).items():
percentage = count / metrics["total_questions"] if metrics["total_questions"] > 0 else 0
report.append(f" - {q_type}: {count} ({percentage:.2%})")
# Add accuracy information if available
if "accuracy" in metrics:
        report.extend([
            "",
            f"Accuracy: {metrics['accuracy']:.2%}",
            f"Correct Answers: {metrics['correct_answers']} out of "
            f"{metrics.get('evaluated_questions', metrics['total_questions'])} evaluated"
        ])
return "\n".join(report)
def process_questions_batch(agent: Any, questions: list, api_base_url: str,
                            progress_file: Optional[str] = None,
                            batch_size: int = 10) -> dict:
"""
Process a batch of questions and track progress.
Args:
agent: Initialized LangChain agent
questions: List of question dictionaries
api_base_url: Base URL for the GAIA API
progress_file: Path to the progress file
        batch_size: Number of questions to process between progress saves
Returns:
Dictionary containing processed questions and answers
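    Example:
        # Illustrative sketch; "agent" and the URL are placeholders.
        outcome = process_questions_batch(agent, questions,
                                          "https://example.com/api")
        print(f"Answered {len(outcome['results'])} new questions")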
"""
from gaiaX.question_handlers import process_question
    # Resolve the progress-file path, then reuse load_progress() rather
    # than duplicating its file-handling and error-recovery logic inline
    if not progress_file:
        progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")
    progress = load_progress(progress_file)
# Get list of already processed questions
processed_ids = set(progress.get("processed_questions", []))
# Filter out already processed questions
remaining_questions = [q for q in questions if q.get("task_id") not in processed_ids]
logger.info(f"Found {len(remaining_questions)} questions to process out of {len(questions)} total")
# Process questions in batches
results = []
for i, question in enumerate(remaining_questions):
if i > 0 and i % batch_size == 0:
logger.info(f"Processed {i}/{len(remaining_questions)} questions. Saving progress...")
save_progress(progress, progress_file)
try:
task_id = question.get("task_id")
logger.info(f"Processing question {i+1}/{len(remaining_questions)}: {task_id}")
# Process the question
start_time = datetime.datetime.now()
result = process_question(agent, question, api_base_url)
end_time = datetime.datetime.now()
# Calculate response time
response_time = (end_time - start_time).total_seconds()
result["response_time"] = response_time
# Add to results and update progress
results.append(result)
progress["processed_questions"].append(task_id)
progress["answers"][task_id] = result.get("answer")
logger.info(f"Completed question {task_id} in {response_time:.2f} seconds")
except Exception as e:
            logger.error(f"Error processing question {question.get('task_id', '')}: {e}")
results.append({
"task_id": question.get("task_id", ""),
"question": question.get("question", ""),
"answer": f"Error: {str(e)}",
"error": str(e)
})
# Save final progress
save_progress(progress, progress_file)
return {
"results": results,
"progress": progress
}
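if __name__ == "__main__":
    # Minimal smoke test: analyze a small set of synthetic answer records
    # and print the formatted report. The records below are illustrative
    # placeholders, not real benchmark output.
    sample_answers = [
        {"task_id": "t1", "question": "Q1", "answer": "42",
         "question_type": "math", "response_time": 1.2},
        {"task_id": "t2", "question": "Q2", "answer": "Paris",
         "question_type": "factual", "response_time": 0.8},
        {"task_id": "t3", "question": "Q3", "answer": "Error: timeout",
         "error": "timeout"},
    ]
    sample_expected = [
        {"task_id": "t1", "answer": "42"},
        {"task_id": "t2", "answer": "London"},
    ]
    metrics = analyze_performance(sample_answers, sample_expected)
    print(format_performance_report(metrics))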