# agentbee / src/utils/ground_truth.py
# Commit e7b4937 by @mangubee — "fix: correct author name formatting in multiple files"
"""Ground truth comparison using GAIA validation dataset.
Author: @mangubee
Since the GAIA API only returns summary stats (X/Y correct) without per-question
correctness, we load the public validation dataset to compare our answers locally.
This enables per-question debugging and error analysis.
"""
import os
import logging
from typing import Dict, Optional
logger = logging.getLogger(__name__)
# ============================================================================
# CONFIG
# ============================================================================
# Local HuggingFace `datasets` cache directory used when downloading the
# GAIA validation split, so repeated runs reuse the same on-disk copy.
CACHE_DIR = os.path.expanduser("~/.cache/gaia_dataset")
# ============================================================================
class GAIAGroundTruth:
    """Load the GAIA validation dataset and provide ground truth answers.

    The GAIA scoring API only reports aggregate results (X/Y correct), so
    this class pulls the public validation split locally to enable
    per-question comparison and error analysis.
    """

    def __init__(self):
        """Initialize an empty loader; the dataset is fetched lazily."""
        # task_id -> stripped final answer string
        self.ground_truth: Dict[str, str] = {}
        # task_id -> full raw dataset row (for metadata such as level/file)
        self.metadata: Dict[str, dict] = {}
        # Guards against re-downloading once the dataset is in memory
        self._loaded = False

    def load_validation_set(self) -> bool:
        """Load the GAIA validation dataset from HuggingFace.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        if self._loaded:
            return True
        try:
            # Imported lazily so the module remains usable without the
            # `datasets` package installed (loading just returns False).
            from datasets import load_dataset

            logger.info("Loading GAIA validation dataset...")
            # "2023_all" bundles every difficulty level; validation is the
            # only split with public answers.
            dataset = load_dataset(
                "gaia-benchmark/GAIA",
                "2023_all",
                split="validation",
                cache_dir=CACHE_DIR,
            )
            # Build task_id -> final_answer mapping and metadata
            for item in dataset:
                task_id = item.get("task_id")
                final_answer = item.get("Final answer")
                # Explicit None check: a truthiness test would silently drop
                # falsy-but-valid answers such as 0 or "0"-like values.
                if task_id and final_answer is not None:
                    self.ground_truth[task_id] = str(final_answer).strip()
                    # Store the full row for metadata access
                    self.metadata[task_id] = dict(item)
            self._loaded = True
            # Lazy %-style args skip formatting when INFO is disabled
            logger.info("Loaded %d ground truth answers", len(self.ground_truth))
            return True
        except Exception as e:
            # Broad catch is deliberate: import, network, and auth failures
            # all degrade to "no ground truth available" instead of crashing.
            logger.error("Failed to load GAIA dataset: %s", e)
            return False

    def get_answer(self, task_id: str) -> Optional[str]:
        """Get the ground truth answer for a task_id.

        Args:
            task_id: Question task ID

        Returns:
            Ground truth answer or None if not found
        """
        # Lazy-load on first access; a failed load leaves the map empty,
        # so lookups simply return None.
        if not self._loaded:
            self.load_validation_set()
        return self.ground_truth.get(task_id)

    def compare_answer(self, task_id: str, submitted_answer: str) -> Optional[bool]:
        """Compare a submitted answer against ground truth.

        Args:
            task_id: Question task ID
            submitted_answer: Answer submitted by the agent

        Returns:
            True if correct, False if incorrect, None if no ground truth
            is available for this task.
        """
        ground_truth = self.get_answer(task_id)
        if ground_truth is None:
            return None
        # Case- and surrounding-whitespace-insensitive exact match; no
        # deeper normalization (numbers, articles) is attempted here.
        submitted = str(submitted_answer).strip().lower()
        expected = str(ground_truth).strip().lower()
        return submitted == expected
# Module-level cache backing the get_ground_truth() singleton accessor.
_ground_truth_instance = None


def get_ground_truth() -> GAIAGroundTruth:
    """Return the shared GAIAGroundTruth loader, creating it on first call.

    Returns:
        The process-wide GAIAGroundTruth instance.
    """
    global _ground_truth_instance
    instance = _ground_truth_instance
    if instance is None:
        instance = GAIAGroundTruth()
        _ground_truth_instance = instance
    return instance