File size: 3,833 Bytes
9fb23b8
 
e7b4937
9fb23b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc583a7
 
9fb23b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc583a7
9fb23b8
 
 
 
 
 
dc583a7
 
9fb23b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Ground truth comparison using GAIA validation dataset.

Author: @mangubee

Since the GAIA API only returns summary stats (X/Y correct) without per-question
correctness, we load the public validation dataset to compare our answers locally.
This enables per-question debugging and error analysis.
"""

import os
import logging
from typing import Dict, Optional

logger = logging.getLogger(__name__)

# ============================================================================
# CONFIG
# ============================================================================
# Local HuggingFace datasets cache directory; passed as cache_dir to
# load_dataset() so repeated runs reuse the downloaded GAIA validation split.
CACHE_DIR = os.path.expanduser("~/.cache/gaia_dataset")
# ============================================================================


class GAIAGroundTruth:
    """Load the GAIA validation dataset and serve per-question ground truth.

    The GAIA API only reports aggregate scores, so the public validation
    answers are mirrored locally here to allow per-question comparison
    and error analysis.
    """

    def __init__(self):
        """Initialize empty ground-truth maps; nothing is loaded yet."""
        # task_id -> stripped final-answer string
        self.ground_truth: Dict[str, str] = {}
        # task_id -> full dataset item (kept for metadata/debugging access)
        self.metadata: Dict[str, dict] = {}
        self._loaded = False  # guards against reloading on every lookup

    def load_validation_set(self) -> bool:
        """Load the GAIA validation split from HuggingFace (idempotent).

        Returns:
            bool: True if the dataset is loaded (or was already loaded),
            False if loading failed (e.g. `datasets` unavailable, no network).
        """
        if self._loaded:
            return True

        try:
            from datasets import load_dataset

            logger.info("Loading GAIA validation dataset...")

            # "2023_all" bundles every difficulty level; the validation
            # split is the one with public answers.
            dataset = load_dataset(
                "gaia-benchmark/GAIA",
                "2023_all",
                split="validation",
                cache_dir=CACHE_DIR
            )

            # Build task_id -> final_answer mapping and metadata.
            for item in dataset:
                task_id = item.get("task_id")
                final_answer = item.get("Final answer")

                # Explicit None checks: a falsy-but-valid answer (e.g. 0,
                # or an empty placeholder we still want recorded) must not
                # be silently dropped by a truthiness test.
                if task_id is not None and final_answer is not None:
                    self.ground_truth[task_id] = str(final_answer).strip()
                    # Store the full item for metadata access.
                    self.metadata[task_id] = dict(item)

            self._loaded = True
            # Lazy %-style args: formatting only happens if the level is enabled.
            logger.info("Loaded %d ground truth answers", len(self.ground_truth))
            return True

        except Exception as e:
            # Best effort: callers receive False/None rather than a crash
            # when the dataset cannot be fetched.
            logger.error("Failed to load GAIA dataset: %s", e)
            return False

    def get_answer(self, task_id: str) -> Optional[str]:
        """Get the ground-truth answer for a task_id.

        Lazily loads the validation set on first use.

        Args:
            task_id: Question task ID

        Returns:
            Ground truth answer or None if not found
        """
        if not self._loaded:
            self.load_validation_set()

        return self.ground_truth.get(task_id)

    def compare_answer(self, task_id: str, submitted_answer: str) -> Optional[bool]:
        """Compare submitted answer against ground truth.

        Comparison is case-insensitive after stripping surrounding
        whitespace; otherwise it is an exact string match (no fuzzy logic).

        Args:
            task_id: Question task ID
            submitted_answer: Answer submitted by agent

        Returns:
            True if correct, False if incorrect, None if no ground truth available
        """
        ground_truth = self.get_answer(task_id)

        if ground_truth is None:
            return None

        # Normalize both answers for comparison.
        submitted = str(submitted_answer).strip().lower()
        expected = str(ground_truth).strip().lower()

        # Exact match comparison.
        return submitted == expected


# Process-wide singleton; created lazily by get_ground_truth().
_ground_truth_instance = None


def get_ground_truth() -> GAIAGroundTruth:
    """Return the shared GAIAGroundTruth instance, creating it on first call.

    Returns:
        GAIAGroundTruth: the lazily constructed module-level singleton.
    """
    global _ground_truth_instance

    instance = _ground_truth_instance
    if instance is None:
        instance = GAIAGroundTruth()
        _ground_truth_instance = instance

    return instance