"""
Utility functions for GAIA Benchmark Agent.

This module provides utility functions for progress tracking,
performance analysis, and other helper functions.
"""

import datetime
import json
import os
from typing import Any, Optional

from gaiaX.config import logger, CONFIG


def load_progress(progress_file: Optional[str] = None) -> dict:
    """
    Load progress from a JSON file.

    Args:
        progress_file: Path to the progress file. Defaults to the path
            configured under CONFIG["paths"]["progress_file"].

    Returns:
        Dictionary containing progress data.
    """
    if not progress_file:
        progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")

    try:
        if os.path.exists(progress_file):
            with open(progress_file, 'r') as f:
                return json.load(f)
        # No progress file yet: start with an empty progress structure.
        return {"processed_questions": [], "answers": {}}
    except Exception as e:
        logger.error(f"Error loading progress from {progress_file}: {e}")
        return {"processed_questions": [], "answers": {}}


def save_progress(progress_data: dict, progress_file: Optional[str] = None) -> bool:
    """
    Save progress to a JSON file.

    Args:
        progress_data: Dictionary containing progress data
        progress_file: Path to the progress file. Defaults to the path
            configured under CONFIG["paths"]["progress_file"].

    Returns:
        True if successful, False otherwise
    """
    if not progress_file:
        progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")

    try:
        with open(progress_file, 'w') as f:
            json.dump(progress_data, f, indent=2)
        return True
    except Exception as e:
        logger.error(f"Error saving progress to {progress_file}: {e}")
        return False
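

# Example round trip (illustrative; assumes the configured progress file
# path is writable):
#
#     progress = load_progress()
#     progress["processed_questions"].append("task-001")
#     progress["answers"]["task-001"] = "42"
#     save_progress(progress)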


def analyze_performance(answers: list, expected_answers: Optional[list] = None) -> dict:
    """
    Analyze the performance of the agent based on answers.

    Args:
        answers: List of answer dictionaries
        expected_answers: Optional list of expected answers for evaluation

    Returns:
        Dictionary containing performance metrics. Accuracy is computed
        only over task IDs present in both `answers` and `expected_answers`.
    """
    total_questions = len(answers)
    # An answer without an "error" key counts as successful.
    successful_answers = sum(1 for a in answers if "error" not in a)
    error_count = total_questions - successful_answers

    response_times = [a["response_time"] for a in answers if "response_time" in a]
    avg_response_time = sum(response_times) / len(response_times) if response_times else 0

    # Count how many answers fall into each question type.
    question_types = {}
    for answer in answers:
        q_type = answer.get("question_type", "unknown")
        question_types[q_type] = question_types.get(q_type, 0) + 1

    # Accuracy is only meaningful when expected answers are provided and
    # share task IDs with the produced answers.
    accuracy = None
    correct_answers = 0
    evaluated_questions = 0
    if expected_answers:
        answer_dict = {a.get("task_id"): a.get("answer") for a in answers}
        expected_dict = {e.get("task_id"): e.get("answer") for e in expected_answers}

        common_ids = set(answer_dict) & set(expected_dict)
        if common_ids:
            correct_answers = sum(
                1 for task_id in common_ids
                if answer_dict[task_id] == expected_dict[task_id]
            )
            evaluated_questions = len(common_ids)
            accuracy = correct_answers / evaluated_questions

    metrics = {
        "total_questions": total_questions,
        "successful_answers": successful_answers,
        "error_count": error_count,
        "success_rate": successful_answers / total_questions if total_questions > 0 else 0,
        "average_response_time": avg_response_time,
        "question_types": question_types
    }

    if accuracy is not None:
        metrics["accuracy"] = accuracy
        metrics["correct_answers"] = correct_answers
        metrics["evaluated_questions"] = evaluated_questions

    return metrics
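

# Expected shape of the answer records (illustrative; "question_type" and
# "response_time" are optional, and an "error" key marks a failed answer):
#
#     answers = [
#         {"task_id": "t1", "answer": "Paris", "response_time": 3.2,
#          "question_type": "text"},
#         {"task_id": "t2", "answer": "Error: timeout", "error": "timeout"},
#     ]
#     metrics = analyze_performance(
#         answers, expected_answers=[{"task_id": "t1", "answer": "Paris"}]
#     )
#     # metrics["accuracy"] == 1.0, computed over the single common task ID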


def format_performance_report(metrics: dict) -> str:
    """
    Format performance metrics into a readable report.

    Args:
        metrics: Dictionary containing performance metrics, as produced
            by analyze_performance()

    Returns:
        Formatted performance report as a string
    """
    report = [
        "=== GAIA Benchmark Agent Performance Report ===",
        f"Generated: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
        "",
        f"Total Questions Processed: {metrics['total_questions']}",
        f"Successful Answers: {metrics['successful_answers']} ({metrics['success_rate']:.2%})",
        f"Errors: {metrics['error_count']}",
        f"Average Response Time: {metrics['average_response_time']:.2f} seconds",
        "",
        "Question Type Distribution:"
    ]

    for q_type, count in metrics.get("question_types", {}).items():
        percentage = count / metrics["total_questions"] if metrics["total_questions"] > 0 else 0
        report.append(f"  - {q_type}: {count} ({percentage:.2%})")

    # Accuracy metrics are only present when expected answers were supplied.
    if "accuracy" in metrics:
        evaluated = metrics.get("evaluated_questions", metrics["total_questions"])
        report.extend([
            "",
            f"Accuracy: {metrics['accuracy']:.2%}",
            f"Correct Answers: {metrics['correct_answers']} out of {evaluated} evaluated"
        ])

    return "\n".join(report)


def process_questions_batch(agent: Any, questions: list, api_base_url: str,
                            progress_file: Optional[str] = None, batch_size: int = 10) -> dict:
    """
    Process a batch of questions and track progress.

    Args:
        agent: Initialized LangChain agent
        questions: List of question dictionaries
        api_base_url: Base URL for the GAIA API
        progress_file: Path to the progress file
        batch_size: Number of questions to process between progress saves

    Returns:
        Dictionary containing processed questions and answers
    """
    # Local import to avoid a potential circular dependency at module load.
    from gaiaX.question_handlers import process_question

    if not progress_file:
        progress_file = CONFIG.get("paths", {}).get("progress_file", "gaia_progress.json")

    # Resume from any previously saved progress.
    progress = load_progress(progress_file)

    processed_ids = set(progress.get("processed_questions", []))

    # Skip questions that were already answered in a previous run.
    remaining_questions = [q for q in questions if q.get("task_id") not in processed_ids]
    logger.info(f"Found {len(remaining_questions)} questions to process out of {len(questions)} total")

    results = []
    for i, question in enumerate(remaining_questions):
        # Checkpoint every `batch_size` questions so a crash loses at most
        # one batch of work.
        if i > 0 and i % batch_size == 0:
            logger.info(f"Processed {i}/{len(remaining_questions)} questions. Saving progress...")
            save_progress(progress, progress_file)

        try:
            task_id = question.get("task_id")
            logger.info(f"Processing question {i + 1}/{len(remaining_questions)}: {task_id}")

            start_time = datetime.datetime.now()
            result = process_question(agent, question, api_base_url)
            end_time = datetime.datetime.now()

            response_time = (end_time - start_time).total_seconds()
            result["response_time"] = response_time

            results.append(result)
            progress["processed_questions"].append(task_id)
            progress["answers"][task_id] = result.get("answer")

            logger.info(f"Completed question {task_id} in {response_time:.2f} seconds")

        except Exception as e:
            # Record the failure but keep going; failed questions are not
            # marked as processed, so they will be retried on the next run.
            logger.error(f"Error processing question: {e}")
            results.append({
                "task_id": question.get("task_id", ""),
                "question": question.get("question", ""),
                "answer": f"Error: {e}",
                "error": str(e)
            })

    save_progress(progress, progress_file)

    return {
        "results": results,
        "progress": progress
    }
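

if __name__ == "__main__":
    # Minimal self-check on synthetic data (illustrative only; real runs go
    # through process_questions_batch with a live agent and API URL).
    demo_answers = [
        {"task_id": "demo-1", "answer": "Paris", "response_time": 2.5,
         "question_type": "text"},
        {"task_id": "demo-2", "answer": "4", "response_time": 1.1,
         "question_type": "math"},
        {"task_id": "demo-3", "answer": "Error: timeout", "error": "timeout"},
    ]
    demo_expected = [
        {"task_id": "demo-1", "answer": "Paris"},
        {"task_id": "demo-2", "answer": "5"},
    ]
    demo_metrics = analyze_performance(demo_answers, demo_expected)
    print(format_performance_report(demo_metrics))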