# fa_agents / test_bench.py
# (Hugging Face repo-page artifacts preserved as comments: uploader "j14i",
#  note "Got 45%", commit e04e3db — these lines were not valid Python.)
"""
GAIA Test Bench - Local evaluation for your agent
Usage:
uv run python test_bench.py # Run on 5 random questions
uv run python test_bench.py --n 10 # Run on 10 random questions
uv run python test_bench.py --level 1 # Run on level 1 only
uv run python test_bench.py --level 1 --n 3 # Run on 3 level 1 questions
uv run python test_bench.py --all # Run on all validation questions
uv run python test_bench.py --task-id 1234 # Run on specific task ID
uv run python test_bench.py --type youtube,file,web,excel,pdf,image,audio,text-only
"""
import argparse
import json
import os
import re
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from datasets import load_dataset
from dotenv import load_dotenv
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field
load_dotenv()
class AnnotatorMetadata(BaseModel):
    """Metadata from GAIA annotators about how to solve the task."""
    # Aliases match the raw GAIA dataset column names verbatim (including the
    # trailing "?" on the last one). All values are kept as free-form strings,
    # even the counts, because that is how the dataset stores them.
    steps: str = Field(default="", alias="Steps")
    number_of_steps: str = Field(default="", alias="Number of steps")
    tools: str = Field(default="", alias="Tools")
    number_of_tools: str = Field(default="", alias="Number of tools")
    how_long: str = Field(default="", alias="How long did this take?")
class GaiaQuestion(BaseModel):
    """A single question from the GAIA dataset."""
    # Accept both the snake_case field names and the dataset's capitalized
    # aliases when validating raw dataset rows.
    model_config = {"populate_by_name": True}
    task_id: str
    question: str = Field(alias="Question")
    level: int = Field(alias="Level")  # difficulty level (1, 2, or 3)
    final_answer: str = Field(default="", alias="Final answer")
    # Attachment file name; empty string when the task has no file.
    file_name: str = ""
    # Attachment path; load_gaia_data() rewrites this to be rooted at the
    # downloaded snapshot directory when present.
    file_path: str | None = None
    annotator: AnnotatorMetadata = Field(
        default_factory=AnnotatorMetadata, alias="Annotator Metadata"
    )
@dataclass
class TestResult:
    """Result of running an agent on a GAIA question."""
    task_id: str
    question: str
    expected: str  # ground-truth "Final answer" from the dataset
    actual: str  # the agent's reply, or "ERROR: ..." when the call raised
    correct: bool  # verdict computed by answers_match()
    level: int
    file_name: str = ""
    file_path: str | None = None
    # Annotator hints copied onto the result so failure analysis does not
    # need to re-load the dataset.
    annotator_steps: str = ""
    annotator_tools: str = ""
    annotator_number_of_steps: str = ""
def normalize_answer(answer: str) -> str:
    """Normalize an answer for comparison: lowercase, strip surrounding
    whitespace, and drop trailing sentence punctuation.

    Returns "" for empty/falsy input so graders never raise on missing data.
    """
    if not answer:
        return ""
    ans = answer.lower().strip()
    # Trailing punctuation ("Paris." vs "paris") should not affect grading.
    ans = re.sub(r"[.,;:!?]+$", "", ans)
    return ans


def answers_match(expected: str, actual: str) -> bool:
    """Check if answers match (with some flexibility).

    Matching is deliberately lenient: exact match after normalization, the
    expected answer appearing as a substring of the agent's (often verbose)
    reply, or numeric equality within an absolute tolerance of 0.01.
    """
    exp = normalize_answer(expected)
    act = normalize_answer(actual)
    if exp == act:
        return True
    # BUGFIX: the substring check must require a non-empty expectation.
    # "" is a substring of every string, so previously any row with an
    # empty "Final answer" graded every agent reply as correct.
    if exp and exp in act:
        return True
    # Numeric fallback: strip everything except digits, sign, and decimal
    # point, then compare with a small absolute tolerance. Non-numeric
    # residue raises ValueError and falls through to "no match".
    try:
        exp_num = float(re.sub(r"[^\d.-]", "", exp))
        act_num = float(re.sub(r"[^\d.-]", "", act))
        if abs(exp_num - act_num) < 0.01:
            return True
    except (ValueError, TypeError):
        pass
    return False
def filter_by_type(
    questions: list[GaiaQuestion], task_type: str | None
) -> list[GaiaQuestion]:
    """Return only the questions matching *task_type* (everything if unset)."""
    if not task_type:
        return questions

    def keep(q: GaiaQuestion) -> bool:
        # Per-question predicate for the requested task type.
        text = q.question.lower()
        fname = q.file_name.lower()
        if task_type == "youtube":
            return "youtube.com" in text or "youtu.be" in text
        if task_type == "file":
            return bool(q.file_name)
        if task_type == "web":
            return "http://" in text or "https://" in text
        if task_type == "excel":
            return fname.endswith((".xlsx", ".xls", ".csv"))
        if task_type == "pdf":
            return fname.endswith(".pdf")
        if task_type == "image":
            return fname.endswith((".png", ".jpg", ".jpeg", ".gif", ".webp"))
        if task_type == "audio":
            return fname.endswith((".mp3", ".wav", ".m4a", ".flac", ".ogg"))
        if task_type == "text-only":
            has_url = "http://" in text or "https://" in text
            return not has_url and not q.file_name
        # Unrecognized type strings keep everything, mirroring the original
        # permissive wildcard branch (argparse restricts the CLI values anyway).
        return True

    return [q for q in questions if keep(q)]
def filter_by_task_id(
    questions: list[GaiaQuestion], task_id: str | None
) -> list[GaiaQuestion]:
    """Return only the questions whose task_id equals *task_id*.

    A falsy *task_id* disables the filter and returns the input unchanged.
    """
    if not task_id:
        return questions
    return [q for q in questions if q.task_id == task_id]
def load_gaia_data(level: int | None = None) -> list[GaiaQuestion]:
    """Download the GAIA validation split and parse it into GaiaQuestion models.

    When *level* is falsy, all three difficulty levels are loaded; a level
    that fails to load is skipped with a warning instead of aborting the run.
    """
    print("Downloading GAIA dataset...")
    data_dir = snapshot_download(repo_id="gaia-benchmark/GAIA", repo_type="dataset")
    collected: list[GaiaQuestion] = []
    for lvl in ([level] if level else [1, 2, 3]):
        try:
            for example in load_dataset(
                data_dir, f"2023_level{lvl}", split="validation"
            ):
                parsed = GaiaQuestion.model_validate(example)
                if parsed.file_path:
                    # Dataset paths are snapshot-relative; anchor them to the
                    # downloaded directory so agents can open them directly.
                    parsed.file_path = os.path.join(data_dir, parsed.file_path)
                collected.append(parsed)
        except Exception as e:
            print(f"Warning: Could not load level {lvl}: {e}")
    print(f"Loaded {len(collected)} questions")
    return collected
def run_test_bench(
    agent,
    n_questions: int | None = 5,
    level: int | None = None,
    task_type: str | None = None,
    run_all: bool = False,
    save_results: bool = True,
    task_id: str | None = None,
) -> list[TestResult]:
    """
    Run the test bench on the agent.
    Args:
        agent: The agent to test (async callable that takes a question string)
        n_questions: Number of questions to test (None for all)
        level: Filter by difficulty level (1, 2, or 3)
        task_type: Filter by task type (youtube, file, web, excel, pdf, image, audio, text-only)
        run_all: If True, run on all questions
        save_results: Save results to JSON file
        task_id: Only run the question with this task ID
    Returns:
        List of TestResult objects
    """
    # BUGFIX: `import asyncio` previously sat inside the per-question loop,
    # re-running the import statement on every iteration; both lazy imports
    # are hoisted to the top of the function.
    import asyncio
    import random

    questions = load_gaia_data(level=level)
    questions = filter_by_type(questions, task_type)
    questions = filter_by_task_id(questions, task_id)
    if not questions:
        # BUGFIX: the old message blamed task_type even when the level or
        # task-id filter emptied the list; report all active filters.
        print(
            f"No questions found (level={level}, type={task_type}, task_id={task_id})"
        )
        return []
    if not run_all and n_questions and n_questions < len(questions):
        # Sample without replacement to keep runs cheap but representative.
        questions = random.sample(questions, n_questions)

    results: list[TestResult] = []
    correct_count = 0
    print(f"\n{'='*60}")
    print(f"Running {len(questions)} questions...")
    print("=" * 60)
    for i, q in enumerate(questions, 1):
        print(f"\n[{i}/{len(questions)}] Level {q.level}: {q.question[:80]}...")
        if q.file_name:
            print(f" File: {q.file_name}")
        if q.annotator.tools:
            print(f" Tools needed:\n{q.annotator.tools}")
        try:
            question_with_file = q.question
            if q.file_path:
                # Attach the local file path so the agent can read the file.
                question_with_file += f"\n\nFile path: {q.file_path}"
            # asyncio.run drives the async agent with a fresh event loop per
            # question. str() guards against non-string agent returns, which
            # previously crashed answers_match()/slicing outside this try.
            actual = str(asyncio.run(agent(question_with_file)))
        except Exception as e:
            # An agent failure scores as wrong but never aborts the run.
            actual = f"ERROR: {e}"
        is_correct = answers_match(q.final_answer, actual)
        if is_correct:
            correct_count += 1
        result = TestResult(
            task_id=q.task_id,
            question=q.question,
            expected=q.final_answer,
            actual=actual,
            correct=is_correct,
            level=q.level,
            file_name=q.file_name,
            file_path=q.file_path,
            annotator_steps=q.annotator.steps,
            annotator_tools=q.annotator.tools,
            annotator_number_of_steps=q.annotator.number_of_steps,
        )
        results.append(result)
        status = "CORRECT" if is_correct else "WRONG"
        print(f" Expected: {q.final_answer}")
        print(f" Got: {actual[:100]}...")
        print(f" Status: {status}")

    # Summary (results is non-empty here: we returned early when no
    # questions survived the filters, so the division is safe).
    print(f"\n{'='*60}")
    print("SUMMARY")
    print("=" * 60)
    print(f"Total: {len(results)}")
    print(f"Correct: {correct_count}")
    print(f"Accuracy: {correct_count/len(results)*100:.1f}%")
    # Per-level breakdown
    for lvl in [1, 2, 3]:
        lvl_results = [r for r in results if r.level == lvl]
        if lvl_results:
            lvl_correct = sum(1 for r in lvl_results if r.correct)
            print(
                f" Level {lvl}: {lvl_correct}/{len(lvl_results)} ({lvl_correct/len(lvl_results)*100:.1f}%)"
            )

    # Tool usage analysis: group accuracy by the annotator's tool list so
    # weak tool categories stand out.
    print(f"\n{'='*60}")
    print("TOOL REQUIREMENTS ANALYSIS")
    print("=" * 60)
    tool_stats: dict[str, dict] = {}
    for r in results:
        # annotator_tools is declared str, but tolerate list-shaped data too.
        tools = (
            r.annotator_tools
            if isinstance(r.annotator_tools, str)
            else ", ".join(r.annotator_tools)
        )
        if tools:
            if tools not in tool_stats:
                tool_stats[tools] = {"total": 0, "correct": 0}
            tool_stats[tools]["total"] += 1
            if r.correct:
                tool_stats[tools]["correct"] += 1
    for tools, stats in sorted(tool_stats.items(), key=lambda x: -x[1]["total"]):
        acc = stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0
        print(f" {tools}: {stats['correct']}/{stats['total']} ({acc:.0f}%)")

    # Save results as a timestamped JSON file for later comparison.
    if save_results:
        results_dir = Path("test_results")
        results_dir.mkdir(exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = results_dir / f"results_{timestamp}.json"
        with open(results_file, "w") as f:
            json.dump(
                {
                    "timestamp": timestamp,
                    "total": len(results),
                    "correct": correct_count,
                    "accuracy": correct_count / len(results),
                    "results": [
                        {
                            "task_id": r.task_id,
                            "level": r.level,
                            "question": r.question,
                            "expected": r.expected,
                            "actual": r.actual,
                            "correct": r.correct,
                            "file_name": r.file_name,
                            "file_path": r.file_path,
                            "annotator_steps": r.annotator_steps,
                            "annotator_tools": r.annotator_tools,
                            "annotator_number_of_steps": r.annotator_number_of_steps,
                        }
                        for r in results
                    ],
                },
                f,
                indent=2,
            )
        print(f"\nResults saved to: {results_file}")
    return results
def main():
    """CLI entry point: parse filter flags, import the agent, run the bench."""
    parser = argparse.ArgumentParser(description="GAIA Test Bench")
    parser.add_argument("--n", type=int, default=5, help="Number of questions to test")
    parser.add_argument("--level", type=int, choices=[1, 2, 3], help="Filter by level")
    parser.add_argument(
        "--type",
        choices=[
            "youtube",
            "file",
            "web",
            "excel",
            "pdf",
            "image",
            "audio",
            "text-only",
        ],
        help="Filter by task type",
    )
    parser.add_argument("--all", action="store_true", help="Run all questions")
    parser.add_argument("--no-save", action="store_true", help="Don't save results")
    parser.add_argument("--task-id", type=str, help="Run specific task ID")
    args = parser.parse_args()

    # Local import (as in the original) so argument parsing above does not
    # depend on the agent module importing cleanly.
    from agent import BasicAgent

    # BUGFIX (dead code removed): the old guard
    #     callable(BasicAgent) and not hasattr(BasicAgent, "__call__")
    # can never be True -- anything callable has __call__ -- so the
    # `BasicAgent(question="")` branch was unreachable. The effective
    # behavior, passing BasicAgent through unchanged, is kept.
    agent = BasicAgent

    run_test_bench(
        agent=agent,
        n_questions=args.n,
        level=args.level,
        task_type=args.type,
        run_all=args.all,
        save_results=not args.no_save,
        task_id=args.task_id,
    )
if __name__ == "__main__":
main()