File size: 3,695 Bytes
99f2cbc
 
 
 
 
 
044eaf7
99f2cbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
044eaf7
99f2cbc
 
 
b93ad3f
99f2cbc
 
044eaf7
99f2cbc
89321e2
b93ad3f
 
 
 
 
 
99f2cbc
 
 
 
 
89321e2
b93ad3f
044eaf7
 
 
 
89321e2
b93ad3f
 
 
044eaf7
99f2cbc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Base class for benchmarks."""

from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Iterator, Tuple
from dataclasses import dataclass
from pathlib import Path
import random


@dataclass
class BenchmarkDataPoint:
    """A single data point from a benchmark.

    Only ``id`` and ``text`` are required; every other field defaults to
    ``None`` so a benchmark can omit what it does not provide.

    Attributes:
        id: Identifier of the data point.
        text: The question/prompt.
        images: Optional list of image paths attached to the question.
        correct_answer: Optional ground-truth answer.
        metadata: Optional dictionary of additional, benchmark-specific
            metadata.
    """
    id: str
    text: str  # The question/prompt
    images: Optional[List[str]] = None  # List of image paths
    correct_answer: Optional[str] = None  # Ground truth answer
    metadata: Optional[Dict[str, Any]] = None  # Additional metadata


class Benchmark(ABC):
    """Abstract base class for benchmarks.

    This class defines the interface for all benchmarks, standardizing
    how data is loaded and accessed across different benchmark datasets.

    Subclasses implement :meth:`_load_data`, which is expected to fill
    ``self.data_points``. The constructor then shuffles the loaded points
    and optionally truncates them to ``max_questions`` entries.
    """

    def __init__(self, data_dir: str, **kwargs):
        """Initialize the benchmark: load, shuffle and optionally truncate.

        Args:
            data_dir (str): Directory containing benchmark data
            **kwargs: Additional configuration parameters, stored as
                ``self.config``. Keys read by this base class:
                random_seed (int): Seed for shuffling the data (default: 42).
                    The data is always shuffled; passing ``None`` explicitly
                    gives a non-reproducible shuffle.
                max_questions (int): Keep only the first ``max_questions``
                    data points after shuffling. ``None`` or ``0`` (the
                    default is ``None``) keeps everything.

        Raises:
            ValueError: If ``max_questions`` is negative.
        """
        self.data_dir = Path(data_dir)
        self.config = kwargs

        # Validate before loading so a bad configuration fails fast. A
        # negative value used as a slice bound would silently drop items
        # from the end instead of limiting the sample size.
        self.max_questions = self.config.get("max_questions", None)
        if self.max_questions is not None and self.max_questions < 0:
            raise ValueError(
                f"max_questions must be non-negative, got {self.max_questions}"
            )

        self.data_points: List["BenchmarkDataPoint"] = []
        self._load_data()
        self._shuffle_data()

        if self.max_questions:
            self.data_points = self.data_points[:self.max_questions]
            # Report the real count: the dataset may hold fewer points than
            # the requested cap.
            print(f"Randomly sampled {len(self.data_points)} questions from {self.__class__.__name__}")
        else:
            print(f"Loaded all {len(self.data_points)} questions from {self.__class__.__name__}")

    @abstractmethod
    def _load_data(self) -> None:
        """Load benchmark data from the data directory.

        Called once from ``__init__`` after ``self.data_points`` has been
        initialised to an empty list and before shuffling.
        """
        pass

    def _shuffle_data(self) -> None:
        """Shuffle the data points in place using the configured seed.

        The seed is read from ``self.config["random_seed"]`` and defaults to
        42, so the order is reproducible between runs unless ``None`` is
        passed explicitly.

        This method is called automatically after data loading.
        """
        random_seed = self.config.get("random_seed", 42)
        # A private generator produces the same permutation as seeding the
        # module-level one, but leaves the process-wide random state alone.
        random.Random(random_seed).shuffle(self.data_points)
        print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")

    def get_data_point(self, index: int) -> "BenchmarkDataPoint":
        """Get a specific data point by index.

        Args:
            index (int): Index of the data point to retrieve. Negative
                indices are not supported.

        Returns:
            BenchmarkDataPoint: The data point at the given index

        Raises:
            IndexError: If ``index`` is negative or not smaller than the
                number of data points.
        """
        if not 0 <= index < len(self.data_points):
            raise IndexError(f"Index {index} out of range for {len(self.data_points)} data points")

        return self.data_points[index]

    def get_subset(self, indices: List[int]) -> List["BenchmarkDataPoint"]:
        """Get a subset of data points by indices.

        Args:
            indices (List[int]): List of indices to retrieve

        Returns:
            List[BenchmarkDataPoint]: List of data points at the given
                indices, in the order the indices were given

        Raises:
            IndexError: If any index is out of range.
        """
        return [self.get_data_point(i) for i in indices]

    def __str__(self) -> str:
        """String representation of the benchmark."""
        return f"{self.__class__.__name__}(data_dir={self.data_dir}, size={len(self)})"

    def __len__(self) -> int:
        """Return the number of data points in the benchmark."""
        return len(self.data_points)

    def __iter__(self) -> Iterator["BenchmarkDataPoint"]:
        """Iterate over all data points in the benchmark."""
        # Go through get_data_point rather than iterating the list directly
        # so that subclasses overriding the accessor are honoured.
        for i in range(len(self)):
            yield self.get_data_point(i)