Spaces:
Sleeping
Sleeping
File size: 3,695 Bytes
99f2cbc 044eaf7 99f2cbc 044eaf7 99f2cbc b93ad3f 99f2cbc 044eaf7 99f2cbc 89321e2 b93ad3f 99f2cbc 89321e2 b93ad3f 044eaf7 89321e2 b93ad3f 044eaf7 99f2cbc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
"""Base class for benchmarks."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Any, Iterator, Tuple
from dataclasses import dataclass
from pathlib import Path
import random
@dataclass
class BenchmarkDataPoint:
    """A single data point from a benchmark.

    Attributes:
        id: Identifier string of this data point.
        text: The question/prompt.
        images: List of image paths; ``None`` when not supplied.
        correct_answer: Ground truth answer; ``None`` when not supplied.
        metadata: Additional metadata; ``None`` when not supplied.
    """

    id: str
    text: str  # The question/prompt
    images: Optional[List[str]] = None  # List of image paths
    correct_answer: Optional[str] = None  # Ground truth answer
    metadata: Optional[Dict[str, Any]] = None  # Additional metadata
class Benchmark(ABC):
    """Abstract base class for benchmarks.

    This class defines the interface for all benchmarks, standardizing
    how data is loaded and accessed across different benchmark datasets.
    Subclasses implement ``_load_data`` to populate ``self.data_points``.
    """

    def __init__(self, data_dir: str, **kwargs):
        """Initialize the benchmark: load, shuffle, then optionally truncate.

        Args:
            data_dir (str): Directory containing benchmark data
            **kwargs: Additional configuration parameters, stored as
                ``self.config``. Keys read by this base class:
                random_seed: Seed used to shuffle the data (default: 42).
                    The data is always shuffled.
                max_questions (int): Keep only the first N data points
                    after shuffling. ``None`` or ``0`` keeps everything.

        Raises:
            ValueError: If ``max_questions`` is negative.
        """
        self.data_dir = Path(data_dir)
        self.config = kwargs
        self.data_points: List["BenchmarkDataPoint"] = []

        # Validate before loading so a bad configuration fails fast. A
        # negative value would otherwise silently drop items from the tail
        # through the ``[:max_questions]`` slice below.
        requested = self.config.get("max_questions", None)
        if requested is not None and requested < 0:
            raise ValueError(
                f"max_questions must be non-negative, got {requested}"
            )

        self._load_data()
        self._shuffle_data()

        self.max_questions = requested
        if self.max_questions:
            self.data_points = self.data_points[:self.max_questions]
            # Report what was actually kept; the dataset may hold fewer
            # items than were requested.
            print(f"Randomly sampled {len(self.data_points)} questions from {self.__class__.__name__}")
        else:
            print(f"Loaded all {len(self.data_points)} questions from {self.__class__.__name__}")

    @abstractmethod
    def _load_data(self) -> None:
        """Load benchmark data from the data directory.

        Called from ``__init__`` after ``self.data_dir``, ``self.config`` and
        an empty ``self.data_points`` list have been set.
        """
        pass

    def _shuffle_data(self) -> None:
        """Shuffle ``self.data_points`` in place using the configured seed.

        The seed is read from ``self.config["random_seed"]`` and defaults to
        42, so the data is always shuffled and the order is reproducible for
        a given seed. This method is called automatically after data loading.

        A dedicated ``random.Random`` instance is used so that the
        process-wide generator of the ``random`` module is not reseeded as a
        side effect; the resulting order is the same as seeding the module
        generator with the same value and shuffling.
        """
        random_seed = self.config.get("random_seed", 42)
        random.Random(random_seed).shuffle(self.data_points)
        print(f"Shuffled {len(self.data_points)} data points with seed {random_seed}")

    def get_data_point(self, index: int) -> "BenchmarkDataPoint":
        """Get a specific data point by index.

        Args:
            index (int): Index of the data point to retrieve

        Returns:
            BenchmarkDataPoint: The data point at the given index

        Raises:
            IndexError: If ``index`` is negative or not smaller than the
                number of data points. Negative indices are rejected rather
                than counted from the end.
        """
        if index < 0 or index >= len(self.data_points):
            raise IndexError(f"Index {index} out of range for {len(self.data_points)} data points")
        return self.data_points[index]

    def get_subset(self, indices: List[int]) -> List["BenchmarkDataPoint"]:
        """Get a subset of data points by indices.

        Args:
            indices (List[int]): List of indices to retrieve

        Returns:
            List[BenchmarkDataPoint]: Data points at the given indices, in
                the order the indices were given

        Raises:
            IndexError: If any index is out of range (see ``get_data_point``).
        """
        return [self.get_data_point(i) for i in indices]

    def __str__(self) -> str:
        """String representation of the benchmark."""
        return f"{self.__class__.__name__}(data_dir={self.data_dir}, size={len(self)})"

    def __len__(self) -> int:
        """Return the number of data points in the benchmark."""
        return len(self.data_points)

    def __iter__(self) -> Iterator["BenchmarkDataPoint"]:
        """Iterate over all data points in the benchmark."""
        # Go through get_data_point so a subclass override of it is honored.
        for index in range(len(self)):
            yield self.get_data_point(index)
|