|
|
""" |
|
|
GAIA Test Bench - Local evaluation for your agent |
|
|
|
|
|
Usage: |
|
|
uv run python test_bench.py # Run on 5 random questions |
|
|
uv run python test_bench.py --n 10 # Run on 10 random questions |
|
|
uv run python test_bench.py --level 1 # Run on level 1 only |
|
|
uv run python test_bench.py --level 1 --n 3 # Run on 3 level 1 questions |
|
|
uv run python test_bench.py --all # Run on all validation questions |
|
|
uv run python test_bench.py --task-id 1234 # Run on specific task ID |
|
|
|
|
|
    uv run python test_bench.py --type youtube     # Filter by task type; one of: youtube, file, web, excel, pdf, image, audio, text-only
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import json |
|
|
import os |
|
|
import re |
|
|
from dataclasses import dataclass |
|
|
from datetime import datetime |
|
|
from pathlib import Path |
|
|
|
|
|
from datasets import load_dataset |
|
|
from dotenv import load_dotenv |
|
|
from huggingface_hub import snapshot_download |
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
load_dotenv() |
|
|
|
|
|
|
|
|
class AnnotatorMetadata(BaseModel):
    """Metadata from GAIA annotators about how to solve the task.

    Field aliases match the raw column names in the GAIA dataset; every
    field defaults to "" so the model stays constructible when a record
    omits annotator metadata (see GaiaQuestion's default_factory).
    """

    # Accept both the alias keys ("Steps", ...) and the Python field names
    # ("steps", ...) at validation time — consistent with GaiaQuestion,
    # which already sets the same configuration.
    model_config = {"populate_by_name": True}

    steps: str = Field(default="", alias="Steps")
    number_of_steps: str = Field(default="", alias="Number of steps")
    tools: str = Field(default="", alias="Tools")
    number_of_tools: str = Field(default="", alias="Number of tools")
    how_long: str = Field(default="", alias="How long did this take?")
|
|
|
|
|
|
|
|
class GaiaQuestion(BaseModel):
    """A single question from the GAIA dataset.

    Parsed from raw validation-split records whose keys are the dataset's
    human-readable column names (see the field aliases below).
    """

    # Allow construction by field name as well as by alias.
    model_config = {"populate_by_name": True}

    task_id: str
    question: str = Field(alias="Question")
    level: int = Field(alias="Level")  # difficulty level: 1, 2, or 3
    final_answer: str = Field(default="", alias="Final answer")
    file_name: str = ""  # attachment file name; "" when the task has no file
    # Relative in the raw record; load_gaia_data() rewrites it to an absolute
    # path inside the downloaded snapshot directory.
    file_path: str | None = None
    annotator: AnnotatorMetadata = Field(
        default_factory=AnnotatorMetadata, alias="Annotator Metadata"
    )
|
|
|
|
|
|
|
|
@dataclass
class TestResult:
    """Result of running an agent on a GAIA question."""

    task_id: str
    question: str
    expected: str  # ground-truth final answer from the dataset
    actual: str  # the agent's response, or "ERROR: ..." if it raised
    correct: bool  # verdict from answers_match(expected, actual)
    level: int  # difficulty level: 1, 2, or 3
    file_name: str = ""
    file_path: str | None = None
    # Annotator hints, carried through for the tool-requirements analysis
    # and the saved JSON report.
    annotator_steps: str = ""
    annotator_tools: str = ""
    annotator_number_of_steps: str = ""
|
|
|
|
|
|
|
|
def normalize_answer(answer: str) -> str:
    """Canonicalize an answer for comparison.

    Lowercases, strips surrounding whitespace, and removes any run of
    trailing punctuation so that e.g. "Paris." and " paris" compare equal.
    """
    if answer:
        lowered = answer.strip().lower()
        return re.sub(r"[.,;:!?]+$", "", lowered)
    return ""
|
|
|
|
|
|
|
|
def answers_match(expected: str, actual: str) -> bool:
    """Check if answers match (with some flexibility).

    Matching is attempted in three increasingly lenient passes:
    exact (normalized) equality, substring containment, and numeric
    comparison with a small absolute tolerance.
    """
    exp = normalize_answer(expected)
    act = normalize_answer(actual)

    if exp == act:
        return True

    # BUG FIX: require a non-empty expected answer before the substring
    # test — "" is a substring of every string, so an empty expected
    # answer previously matched ANY non-empty agent response.
    if exp and exp in act:
        return True

    # Numeric comparison: strip currency symbols, thousands separators,
    # units, etc., then compare with a small absolute tolerance.
    try:
        exp_num = float(re.sub(r"[^\d.-]", "", exp))
        act_num = float(re.sub(r"[^\d.-]", "", act))
        if abs(exp_num - act_num) < 0.01:
            return True
    except (ValueError, TypeError):
        pass

    return False
|
|
|
|
|
|
|
|
def filter_by_type(
    questions: list[GaiaQuestion], task_type: str | None
) -> list[GaiaQuestion]:
    """Filter questions by task type.

    Returns the input unchanged when task_type is falsy; an unrecognized
    task_type keeps every question.
    """
    if not task_type:
        return questions

    kept: list[GaiaQuestion] = []
    for question in questions:
        text = question.question.lower()
        fname = question.file_name.lower()

        if task_type == "youtube":
            keep = "youtube.com" in text or "youtu.be" in text
        elif task_type == "file":
            keep = bool(question.file_name)
        elif task_type == "web":
            keep = "http://" in text or "https://" in text
        elif task_type == "excel":
            keep = fname.endswith((".xlsx", ".xls", ".csv"))
        elif task_type == "pdf":
            keep = fname.endswith(".pdf")
        elif task_type == "image":
            keep = fname.endswith((".png", ".jpg", ".jpeg", ".gif", ".webp"))
        elif task_type == "audio":
            keep = fname.endswith((".mp3", ".wav", ".m4a", ".flac", ".ogg"))
        elif task_type == "text-only":
            # Text-only means: no URL in the question and no attached file.
            has_link = "http://" in text or "https://" in text
            keep = not has_link and not question.file_name
        else:
            keep = True

        if keep:
            kept.append(question)

    return kept
|
|
|
|
|
|
|
|
def filter_by_task_id(
    questions: list[GaiaQuestion], task_id: str | None
) -> list[GaiaQuestion]:
    """Filter questions by task id.

    Returns the input unchanged when task_id is falsy; otherwise keeps
    only the question(s) whose task_id matches exactly.
    """
    if not task_id:
        return questions
    return [q for q in questions if q.task_id == task_id]
|
|
|
|
|
|
|
|
def load_gaia_data(level: int | None = None) -> list[GaiaQuestion]:
    """Load GAIA validation dataset with all metadata.

    Downloads (or reuses the cached) GAIA snapshot, parses every validation
    example of the requested level(s) into GaiaQuestion models, and rewrites
    each attachment path to an absolute path inside the snapshot directory.
    """
    print("Downloading GAIA dataset...")
    data_dir = snapshot_download(repo_id="gaia-benchmark/GAIA", repo_type="dataset")

    loaded: list[GaiaQuestion] = []

    # A falsy level means "all three levels".
    wanted_levels = [level] if level else [1, 2, 3]
    for lvl in wanted_levels:
        try:
            split = load_dataset(data_dir, f"2023_level{lvl}", split="validation")
            for raw in split:
                parsed = GaiaQuestion.model_validate(raw)
                if parsed.file_path:
                    parsed.file_path = os.path.join(data_dir, parsed.file_path)
                loaded.append(parsed)
        except Exception as e:
            # Best-effort: warn and continue so one bad level doesn't
            # prevent the others from loading.
            print(f"Warning: Could not load level {lvl}: {e}")

    print(f"Loaded {len(loaded)} questions")
    return loaded
|
|
|
|
|
|
|
|
def run_test_bench(
    agent,
    n_questions: int | None = 5,
    level: int | None = None,
    task_type: str | None = None,
    run_all: bool = False,
    save_results: bool = True,
    task_id: str | None = None,
) -> list[TestResult]:
    """
    Run the test bench on the agent.

    Args:
        agent: The agent to test — an async callable that takes a question
            string (it is invoked via asyncio.run below)
        n_questions: Number of questions to test (None for all)
        level: Filter by difficulty level (1, 2, or 3)
        task_type: Filter by task type (youtube, file, web, excel, pdf, image, audio, text-only)
        run_all: If True, run on all questions (ignores n_questions)
        save_results: Save results to a timestamped JSON file under test_results/
        task_id: If given, run only the question with this task id

    Returns:
        List of TestResult objects
    """
    import random

    # Load and narrow the question set; each filter is a no-op when its
    # argument is None/empty.
    questions = load_gaia_data(level=level)
    questions = filter_by_type(questions, task_type)
    questions = filter_by_task_id(questions, task_id)

    if not questions:
        print(f"No questions found for type '{task_type}'")
        return []

    # Sample a random subset unless the caller asked for everything.
    if not run_all and n_questions and n_questions < len(questions):
        questions = random.sample(questions, n_questions)

    results: list[TestResult] = []
    correct_count = 0

    print(f"\n{'='*60}")
    print(f"Running {len(questions)} questions...")
    print("=" * 60)

    for i, q in enumerate(questions, 1):
        print(f"\n[{i}/{len(questions)}] Level {q.level}: {q.question[:80]}...")

        if q.file_name:
            print(f"  File: {q.file_name}")
        if q.annotator.tools:
            print(f"  Tools needed:\n{q.annotator.tools}")

        try:
            # Attach the local file path (if any) to the prompt so the agent
            # can open the file itself.
            question_with_file = q.question
            if q.file_path:
                question_with_file += f"\n\nFile path: {q.file_path}"
            import asyncio

            # NOTE(review): assumes `agent(...)` returns a coroutine yielding
            # a string; a non-string result would break normalize_answer and
            # the [:100] slice below — confirm the agent's contract.
            actual = asyncio.run(agent(question_with_file))
        except Exception as e:
            # Any agent failure is recorded as a wrong answer, not a crash.
            actual = f"ERROR: {e}"

        is_correct = answers_match(q.final_answer, actual)
        if is_correct:
            correct_count += 1

        result = TestResult(
            task_id=q.task_id,
            question=q.question,
            expected=q.final_answer,
            actual=actual,
            correct=is_correct,
            level=q.level,
            file_name=q.file_name,
            file_path=q.file_path,
            annotator_steps=q.annotator.steps,
            annotator_tools=q.annotator.tools,
            annotator_number_of_steps=q.annotator.number_of_steps,
        )
        results.append(result)

        status = "CORRECT" if is_correct else "WRONG"
        print(f"  Expected: {q.final_answer}")
        print(f"  Got: {actual[:100]}...")
        print(f"  Status: {status}")

    # Overall summary. len(results) is never 0 here: the empty-questions
    # case returned early above and every question appends one result.
    print(f"\n{'='*60}")
    print("SUMMARY")
    print("=" * 60)
    print(f"Total: {len(results)}")
    print(f"Correct: {correct_count}")
    print(f"Accuracy: {correct_count/len(results)*100:.1f}%")

    # Per-level accuracy breakdown (only levels actually present).
    for lvl in [1, 2, 3]:
        lvl_results = [r for r in results if r.level == lvl]
        if lvl_results:
            lvl_correct = sum(1 for r in lvl_results if r.correct)
            print(
                f"  Level {lvl}: {lvl_correct}/{len(lvl_results)} ({lvl_correct/len(lvl_results)*100:.1f}%)"
            )

    # Group accuracy by the annotator's tool list (the raw string is used
    # as the grouping key), most frequent groups first.
    print(f"\n{'='*60}")
    print("TOOL REQUIREMENTS ANALYSIS")
    print("=" * 60)
    tool_stats: dict[str, dict] = {}
    for r in results:
        tools = (
            r.annotator_tools
            if isinstance(r.annotator_tools, str)
            else ", ".join(r.annotator_tools)
        )
        if tools:
            if tools not in tool_stats:
                tool_stats[tools] = {"total": 0, "correct": 0}
            tool_stats[tools]["total"] += 1
            if r.correct:
                tool_stats[tools]["correct"] += 1

    for tools, stats in sorted(tool_stats.items(), key=lambda x: -x[1]["total"]):
        acc = stats["correct"] / stats["total"] * 100 if stats["total"] > 0 else 0
        print(f"  {tools}: {stats['correct']}/{stats['total']} ({acc:.0f}%)")

    # Persist the full run (summary + per-question details) as JSON.
    if save_results:
        results_dir = Path("test_results")
        results_dir.mkdir(exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        results_file = results_dir / f"results_{timestamp}.json"

        with open(results_file, "w") as f:
            json.dump(
                {
                    "timestamp": timestamp,
                    "total": len(results),
                    "correct": correct_count,
                    "accuracy": correct_count / len(results),
                    "results": [
                        {
                            "task_id": r.task_id,
                            "level": r.level,
                            "question": r.question,
                            "expected": r.expected,
                            "actual": r.actual,
                            "correct": r.correct,
                            "file_name": r.file_name,
                            "file_path": r.file_path,
                            "annotator_steps": r.annotator_steps,
                            "annotator_tools": r.annotator_tools,
                            "annotator_number_of_steps": r.annotator_number_of_steps,
                        }
                        for r in results
                    ],
                },
                f,
                indent=2,
            )
        print(f"\nResults saved to: {results_file}")

    return results
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments, construct the agent, and run the test bench."""
    parser = argparse.ArgumentParser(description="GAIA Test Bench")
    parser.add_argument("--n", type=int, default=5, help="Number of questions to test")
    parser.add_argument("--level", type=int, choices=[1, 2, 3], help="Filter by level")
    parser.add_argument(
        "--type",
        choices=[
            "youtube",
            "file",
            "web",
            "excel",
            "pdf",
            "image",
            "audio",
            "text-only",
        ],
        help="Filter by task type",
    )
    parser.add_argument("--all", action="store_true", help="Run all questions")
    parser.add_argument("--no-save", action="store_true", help="Don't save results")
    parser.add_argument("--task-id", type=str, help="Run specific task ID")
    args = parser.parse_args()

    # Imported here so argument parsing (e.g. --help) works even when the
    # agent module has heavy import-time dependencies.
    from agent import BasicAgent

    # BUG FIX: the old check `callable(BasicAgent) and not
    # hasattr(BasicAgent, "__call__")` could never be true — every callable
    # has __call__ — so a class-valued BasicAgent was passed through
    # uninstantiated and asyncio.run(agent(q)) received an instance instead
    # of a coroutine. Instantiate when BasicAgent is a class; otherwise use
    # the callable as-is.
    if isinstance(BasicAgent, type):
        agent = BasicAgent(question="")
    else:
        agent = BasicAgent

    run_test_bench(
        agent=agent,
        n_questions=args.n,
        level=args.level,
        task_type=args.type,
        run_all=args.all,
        save_results=not args.no_save,
        task_id=args.task_id,
    )
|
|
|
|
|
|
|
|
# Script entry point: `uv run python test_bench.py [options]`.
if __name__ == "__main__":
    main()
|
|
|