""" GAIA Test Bench - Local evaluation for your agent Usage: uv run python test_bench.py # Run on 5 random questions uv run python test_bench.py --n 10 # Run on 10 random questions uv run python test_bench.py --level 1 # Run on level 1 only uv run python test_bench.py --level 1 --n 3 # Run on 3 level 1 questions uv run python test_bench.py --all # Run on all validation questions uv run python test_bench.py --task-id 1234 # Run on specific task ID uv run python test_bench.py --type youtube,file,web,excel,pdf,image,audio,text-only """ import argparse import json import os import re from dataclasses import dataclass from datetime import datetime from pathlib import Path from datasets import load_dataset from dotenv import load_dotenv from huggingface_hub import snapshot_download from pydantic import BaseModel, Field load_dotenv() class AnnotatorMetadata(BaseModel): """Metadata from GAIA annotators about how to solve the task.""" steps: str = Field(default="", alias="Steps") number_of_steps: str = Field(default="", alias="Number of steps") tools: str = Field(default="", alias="Tools") number_of_tools: str = Field(default="", alias="Number of tools") how_long: str = Field(default="", alias="How long did this take?") class GaiaQuestion(BaseModel): """A single question from the GAIA dataset.""" model_config = {"populate_by_name": True} task_id: str question: str = Field(alias="Question") level: int = Field(alias="Level") final_answer: str = Field(default="", alias="Final answer") file_name: str = "" file_path: str | None = None annotator: AnnotatorMetadata = Field( default_factory=AnnotatorMetadata, alias="Annotator Metadata" ) @dataclass class TestResult: """Result of running an agent on a GAIA question.""" task_id: str question: str expected: str actual: str correct: bool level: int file_name: str = "" file_path: str | None = None annotator_steps: str = "" annotator_tools: str = "" annotator_number_of_steps: str = "" def normalize_answer(answer: str) -> str: """Normalize answer for comparison.""" if not answer: return "" ans = answer.lower().strip() ans = re.sub(r"[.,;:!?]+$", "", ans) return ans def answers_match(expected: str, actual: str) -> bool: """Check if answers match (with some flexibility).""" exp = normalize_answer(expected) act = normalize_answer(actual) if exp == act: return True if exp in act: return True try: exp_num = float(re.sub(r"[^\d.-]", "", exp)) act_num = float(re.sub(r"[^\d.-]", "", act)) if abs(exp_num - act_num) < 0.01: return True except (ValueError, TypeError): pass return False def filter_by_type( questions: list[GaiaQuestion], task_type: str | None ) -> list[GaiaQuestion]: """Filter questions by task type.""" if not task_type: return questions filtered = [] for q in questions: question_text = q.question.lower() file_name = q.file_name.lower() match task_type: case "youtube": if "youtube.com" in question_text or "youtu.be" in question_text: filtered.append(q) case "file": if q.file_name: filtered.append(q) case "web": if "http://" in question_text or "https://" in question_text: filtered.append(q) case "excel": if file_name.endswith((".xlsx", ".xls", ".csv")): filtered.append(q) case "pdf": if file_name.endswith(".pdf"): filtered.append(q) case "image": if file_name.endswith((".png", ".jpg", ".jpeg", ".gif", ".webp")): filtered.append(q) case "audio": if file_name.endswith((".mp3", ".wav", ".m4a", ".flac", ".ogg")): filtered.append(q) case "text-only": has_url = "http://" in question_text or "https://" in question_text has_file = bool(q.file_name) if not has_url and not has_file: filtered.append(q) case _: filtered.append(q) return filtered def filter_by_task_id( questions: list[GaiaQuestion], task_id: str | None ) -> list[GaiaQuestion]: """Filter questions by task id.""" if not task_id: return questions filtered = [] for q in questions: if q.task_id == task_id: filtered.append(q) return filtered def load_gaia_data(level: int | None = None) -> list[GaiaQuestion]: """Load GAIA validation dataset with all metadata.""" print("Downloading GAIA dataset...") data_dir = snapshot_download(repo_id="gaia-benchmark/GAIA", repo_type="dataset") questions: list[GaiaQuestion] = [] levels = [level] if level else [1, 2, 3] for lvl in levels: try: dataset = load_dataset(data_dir, f"2023_level{lvl}", split="validation") for example in dataset: question = GaiaQuestion.model_validate(example) # Fix file_path to be absolute if question.file_path: question.file_path = os.path.join(data_dir, question.file_path) questions.append(question) except Exception as e: print(f"Warning: Could not load level {lvl}: {e}") print(f"Loaded {len(questions)} questions") return questions def run_test_bench( agent, n_questions: int | None = 5, level: int | None = None, task_type: str | None = None, run_all: bool = False, save_results: bool = True, task_id: str | None = None, ) -> list[TestResult]: """ Run the test bench on the agent. Args: agent: The agent to test (callable that takes question string) n_questions: Number of questions to test (None for all) level: Filter by difficulty level (1, 2, or 3) task_type: Filter by task type (youtube, file, web, excel, pdf, image, audio, text-only) run_all: If True, run on all questions save_results: Save results to JSON file Returns: List of TestResult objects """ import random questions = load_gaia_data(level=level) questions = filter_by_type(questions, task_type) questions = filter_by_task_id(questions, task_id) if not questions: print(f"No questions found for type '{task_type}'") return [] if not run_all and n_questions and n_questions < len(questions): questions = random.sample(questions, n_questions) results: list[TestResult] = [] correct_count = 0 print(f"\n{'='*60}") print(f"Running {len(questions)} questions...") print("=" * 60) for i, q in enumerate(questions, 1): print(f"\n[{i}/{len(questions)}] Level {q.level}: {q.question[:80]}...") if q.file_name: print(f" File: {q.file_name}") if q.annotator.tools: print(f" Tools needed:\n{q.annotator.tools}") try: question_with_file = q.question if q.file_path: question_with_file += f"\n\nFile path: {q.file_path}" import asyncio actual = asyncio.run(agent(question_with_file)) except Exception as e: actual = f"ERROR: {e}" is_correct = answers_match(q.final_answer, actual) if is_correct: correct_count += 1 result = TestResult( task_id=q.task_id, question=q.question, expected=q.final_answer, actual=actual, correct=is_correct, level=q.level, file_name=q.file_name, file_path=q.file_path, annotator_steps=q.annotator.steps, annotator_tools=q.annotator.tools, annotator_number_of_steps=q.annotator.number_of_steps, ) results.append(result) status = "CORRECT" if is_correct else "WRONG" print(f" Expected: {q.final_answer}") print(f" Got: {actual[:100]}...") print(f" Status: {status}") # Summary print(f"\n{'='*60}") print("SUMMARY") print("=" * 60) print(f"Total: {len(results)}") print(f"Correct: {correct_count}") print(f"Accuracy: {correct_count/len(results)*100:.1f}%") # Per-level breakdown for lvl in [1, 2, 3]: lvl_results = [r for r in results if r.level == lvl] if lvl_results: lvl_correct = sum(1 for r in lvl_results if r.correct) print( f" Level {lvl}: {lvl_correct}/{len(lvl_results)} ({lvl_correct/len(lvl_results)*100:.1f}%)" ) # Tool usage analysis print(f"\n{'='*60}") print("TOOL REQUIREMENTS ANALYSIS") print("=" * 60) tool_stats: dict[str, dict] = {} for r in results: tools = ( r.annotator_tools if isinstance(r.annotator_tools, str) else ", ".join(r.annotator_tools) ) if tools: if tools not in tool_stats: tool_stats[tools] = {"total": 0, "correct": 0} tool_stats[tools]["total"] += 1 if r.correct: tool_stats[tools]["correct"] += 1 for tools, stats in sorted(tool_stats.items(), key=lambda x: -x[1]["total"]): acc = stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0 print(f" {tools}: {stats['correct']}/{stats['total']} ({acc:.0f}%)") # Save results if save_results: results_dir = Path("test_results") results_dir.mkdir(exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") results_file = results_dir / f"results_{timestamp}.json" with open(results_file, "w") as f: json.dump( { "timestamp": timestamp, "total": len(results), "correct": correct_count, "accuracy": correct_count / len(results), "results": [ { "task_id": r.task_id, "level": r.level, "question": r.question, "expected": r.expected, "actual": r.actual, "correct": r.correct, "file_name": r.file_name, "file_path": r.file_path, "annotator_steps": r.annotator_steps, "annotator_tools": r.annotator_tools, "annotator_number_of_steps": r.annotator_number_of_steps, } for r in results ], }, f, indent=2, ) print(f"\nResults saved to: {results_file}") return results def main(): parser = argparse.ArgumentParser(description="GAIA Test Bench") parser.add_argument("--n", type=int, default=5, help="Number of questions to test") parser.add_argument("--level", type=int, choices=[1, 2, 3], help="Filter by level") parser.add_argument( "--type", choices=[ "youtube", "file", "web", "excel", "pdf", "image", "audio", "text-only", ], help="Filter by task type", ) parser.add_argument("--all", action="store_true", help="Run all questions") parser.add_argument("--no-save", action="store_true", help="Don't save results") parser.add_argument("--task-id", type=str, help="Run specific task ID") args = parser.parse_args() from agent import BasicAgent if callable(BasicAgent) and not hasattr(BasicAgent, "__call__"): agent = BasicAgent(question="") else: agent = BasicAgent run_test_bench( agent=agent, n_questions=args.n, level=args.level, task_type=args.type, run_all=args.all, save_results=not args.no_save, task_id=args.task_id, ) if __name__ == "__main__": main()