"""Ground truth comparison using GAIA validation dataset.

Author: @mangubee

Since the GAIA API only returns summary stats (X/Y correct) without
per-question correctness, we load the public validation dataset to compare
our answers locally. This enables per-question debugging and error analysis.
"""
import os
import logging
from typing import Dict, Optional

logger = logging.getLogger(__name__)

# ============================================================================
# CONFIG
# ============================================================================
CACHE_DIR = os.path.expanduser("~/.cache/gaia_dataset")
# ============================================================================


class GAIAGroundTruth:
    """Load GAIA validation dataset and provide ground truth answers."""

    def __init__(self):
        """Set up empty lookup tables; nothing is fetched until first use."""
        # task_id -> final answer, stored as a stripped string
        self.ground_truth: Dict[str, str] = {}
        # task_id -> full dataset record, kept for metadata access
        self.metadata: Dict[str, dict] = {}
        self._loaded = False

    def load_validation_set(self) -> bool:
        """Load GAIA validation dataset from HuggingFace.

        Idempotent: returns immediately once a load has succeeded.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        if self._loaded:
            return True

        try:
            from datasets import load_dataset

            logger.info("Loading GAIA validation dataset...")

            # The 2023_all config bundles every difficulty level; the
            # validation split is the public one that ships with answers.
            dataset = load_dataset(
                "gaia-benchmark/GAIA",
                "2023_all",
                split="validation",
                cache_dir=CACHE_DIR
            )

            # Build the task_id -> answer map plus a metadata side table.
            for record in dataset:
                task_id = record.get("task_id")
                answer = record.get("Final answer")
                if not (task_id and answer):
                    continue  # skip rows missing either field
                self.ground_truth[task_id] = str(answer).strip()
                self.metadata[task_id] = dict(record)

            self._loaded = True
            logger.info(f"Loaded {len(self.ground_truth)} ground truth answers")
            return True
        except Exception as e:
            # Best-effort: a missing or gated dataset simply disables
            # local ground-truth comparison rather than crashing callers.
            logger.error(f"Failed to load GAIA dataset: {e}")
            return False

    def get_answer(self, task_id: str) -> Optional[str]:
        """Return the ground truth answer for a task, loading lazily.

        Args:
            task_id: Question task ID

        Returns:
            Ground truth answer or None if not found
        """
        if not self._loaded:
            self.load_validation_set()
        return self.ground_truth.get(task_id)

    def compare_answer(self, task_id: str, submitted_answer: str) -> Optional[bool]:
        """Check a submitted answer against the ground truth.

        Args:
            task_id: Question task ID
            submitted_answer: Answer submitted by agent

        Returns:
            True if correct, False if incorrect, None if no ground truth
            available
        """
        expected = self.get_answer(task_id)
        if expected is None:
            return None

        # Case-insensitive exact match after trimming whitespace on both sides.
        normalized_submitted = str(submitted_answer).strip().lower()
        normalized_expected = str(expected).strip().lower()
        return normalized_submitted == normalized_expected


# Singleton instance
_ground_truth_instance = None


def get_ground_truth() -> GAIAGroundTruth:
    """Return the process-wide GAIAGroundTruth singleton, creating it lazily.

    Returns:
        GAIAGroundTruth instance
    """
    global _ground_truth_instance
    if _ground_truth_instance is None:
        _ground_truth_instance = GAIAGroundTruth()
    return _ground_truth_instance