File size: 3,833 Bytes
9fb23b8 e7b4937 9fb23b8 dc583a7 9fb23b8 dc583a7 9fb23b8 dc583a7 9fb23b8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
"""Ground truth comparison using GAIA validation dataset.
Author: @mangubee
Since the GAIA API only returns summary stats (X/Y correct) without per-question
correctness, we load the public validation dataset to compare our answers locally.
This enables per-question debugging and error analysis.
"""
import os
import logging
from typing import Dict, Optional
logger = logging.getLogger(__name__)
# ============================================================================
# CONFIG
# ============================================================================
# Local cache directory handed to `datasets.load_dataset` so the GAIA
# download is reused across runs. Expanded to an absolute path at import time.
CACHE_DIR = os.path.expanduser("~/.cache/gaia_dataset")
# ============================================================================
class GAIAGroundTruth:
    """Load GAIA validation dataset and provide ground truth answers."""

    def __init__(self):
        """Initialize ground truth loader."""
        # task_id -> expected final-answer string (whitespace-stripped).
        self.ground_truth: Dict[str, str] = {}
        # task_id -> the full dataset record, for per-question metadata access.
        self.metadata: Dict[str, dict] = {}
        # Set once the dataset has been fetched, so repeated lookups do not
        # re-trigger a download.
        self._loaded = False

    def load_validation_set(self) -> bool:
        """Load GAIA validation dataset from HuggingFace.

        Returns:
            bool: True if loaded successfully, False otherwise
        """
        if self._loaded:
            return True
        try:
            from datasets import load_dataset

            logger.info("Loading GAIA validation dataset...")
            # The "2023_all" config bundles every difficulty level; the
            # validation split is the one whose answers are public.
            dataset = load_dataset(
                "gaia-benchmark/GAIA",
                "2023_all",
                split="validation",
                cache_dir=CACHE_DIR,
            )
            # Index the records by task_id for O(1) lookup later.
            for record in dataset:
                tid = record.get("task_id")
                answer = record.get("Final answer")
                if not (tid and answer):
                    # Skip records missing either key field.
                    continue
                self.ground_truth[tid] = str(answer).strip()
                # Keep the whole record so callers can inspect metadata.
                self.metadata[tid] = dict(record)
            self._loaded = True
            logger.info(f"Loaded {len(self.ground_truth)} ground truth answers")
            return True
        except Exception as e:
            # Best-effort: a missing `datasets` package or a network failure
            # degrades to "no ground truth available" rather than crashing.
            logger.error(f"Failed to load GAIA dataset: {e}")
            return False

    def get_answer(self, task_id: str) -> Optional[str]:
        """Get ground truth answer for a task_id.

        Args:
            task_id: Question task ID

        Returns:
            Ground truth answer or None if not found
        """
        # Lazily load the dataset on first use.
        if not self._loaded:
            self.load_validation_set()
        return self.ground_truth.get(task_id)

    def compare_answer(self, task_id: str, submitted_answer: str) -> Optional[bool]:
        """Compare submitted answer against ground truth.

        Args:
            task_id: Question task ID
            submitted_answer: Answer submitted by agent

        Returns:
            True if correct, False if incorrect, None if no ground truth available
        """
        expected_raw = self.get_answer(task_id)
        if expected_raw is None:
            return None
        # Case- and surrounding-whitespace-insensitive exact match.
        submitted = str(submitted_answer).strip().lower()
        expected = str(expected_raw).strip().lower()
        return submitted == expected
# Singleton instance, created lazily by get_ground_truth() on first call.
_ground_truth_instance: Optional["GAIAGroundTruth"] = None
def get_ground_truth() -> GAIAGroundTruth:
    """Get or create singleton ground truth instance.

    Returns:
        GAIAGroundTruth instance
    """
    global _ground_truth_instance
    instance = _ground_truth_instance
    if instance is None:
        # First call: build and publish the shared instance.
        instance = GAIAGroundTruth()
        _ground_truth_instance = instance
    return instance
|