File size: 17,067 Bytes
5374a2d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 |
import asyncio
import random
import numpy as np
from abc import ABC, abstractmethod
from typing import Tuple, Optional, List, Dict, Any
from ..core.logging import logger
from ..core.callbacks import timeout, TimeoutException
from ..utils.sanitize import sanitize
from .lcb_utils.evaluation import estimate_pass_at_k
class Benchmark(ABC):
    """
    Abstract base class for defining benchmarks. This class provides methods to load,
    retrieve, and evaluate benchmark data, with train, dev, and test splits.
    """

    def __init__(self, name: str, path: str, mode: str = "all", **kwargs):
        """
        Initializes the benchmark with a name and data path.

        Args:
            name (str): The name of the benchmark.
            path (str): The path to the dataset.
            mode (str): Which type of data to load, choices: ["all", "train", "dev", "test"].
            **kwargs: Additional parameters for customization.
        """
        valid_mode = ["all", "train", "dev", "test"]
        # BUG FIX: message previously read "Invalid value for model".
        assert mode in valid_mode, f"Invalid value for mode: {mode}. Available choices: {valid_mode}"
        self.name = name
        self.path = path
        self.mode = mode
        self.kwargs = kwargs
        # Internal storage for the individual dataset splits.
        self._train_data: Optional[List[dict]] = None
        self._dev_data: Optional[List[dict]] = None
        self._test_data: Optional[List[dict]] = None
        # Load data from `self.path` (subclasses populate the split attributes above).
        self._load_data()

    @abstractmethod
    def _load_data(self):
        """
        Abstract method to load data from `self.path` and assign it to
        `_train_data`, `_dev_data`, and `_test_data` if applicable.
        """
        pass

    @abstractmethod
    def _get_id(self, example: Any) -> Any:
        """
        Abstract method to return the id for a given example.
        """
        pass

    @abstractmethod
    def _get_label(self, example: Any) -> Any:
        """
        Abstract method to return the ground-truth label for a given example.

        Args:
            example (Any): The input example for which the label is needed.

        Returns:
            Any: The ground-truth label associated with the example.
        """
        pass

    @abstractmethod
    def evaluate(self, prediction: Any, label: Any) -> dict:
        """
        Abstract method to evaluate a single prediction against the ground-truth label.

        Args:
            prediction (Any): The predicted output.
            label (Any): The actual ground-truth label.

        Returns:
            dict: A dictionary containing evaluation metrics.
        """
        pass

    async def async_evaluate(self, prediction: Any, label: Any) -> dict:
        """
        Asynchronous version of `evaluate`; runs the synchronous `evaluate` in a
        worker thread so the event loop is not blocked.

        Args:
            prediction (Any): The predicted output.
            label (Any): The actual ground-truth label.

        Returns:
            dict: A dictionary containing evaluation metrics.
        """
        return await asyncio.to_thread(self.evaluate, prediction, label)

    # NOTE: annotations fixed below — `get_label` / `get_id` take a single
    # example (previously mis-annotated as List[Any]).
    def get_label(self, example: Any) -> Any:
        """Return the ground-truth label for a single example."""
        return self._get_label(example=example)

    def get_labels(self, examples: List[Any]) -> List[Any]:
        """Return the ground-truth labels for a list of examples."""
        return [self._get_label(example=example) for example in examples]

    def get_id(self, example: Any) -> Any:
        """Return the id for a single example."""
        return self._get_id(example=example)

    def get_ids(self, examples: List[Any]) -> List[Any]:
        """Return the ids for a list of examples."""
        return [self._get_id(example=example) for example in examples]

    def _warn_missing_split(self, split_name: str) -> List[Any]:
        """Log a warning for an unloaded split and return an empty list."""
        logger.warning(f"{split_name} data for benchmark {type(self).__name__} is not loaded or None. Return an empty list.")
        return []

    def get_data_by_mode(self, mode: str = "test") -> List[Any]:
        """
        Get the data from the benchmark by mode. Returns an empty list (with a
        warning) when the requested split was not loaded.
        """
        assert mode in ["train", "dev", "test"], f"Invalid value for mode: {mode}. Available choices: ['train', 'dev', 'test']"
        split_map = {"train": self._train_data, "dev": self._dev_data, "test": self._test_data}
        data = split_map[mode]
        if data is None:
            return self._warn_missing_split(mode.capitalize())
        return data

    def get_example_by_id(self, example_id: Any, mode: Optional[str] = None) -> Optional[Any]:
        """
        Get an example from the benchmark by its id.

        Args:
            example_id (Any): The id of the example to retrieve.
            mode (Optional[str]): The mode to retrieve the example from,
                choices: ["train", "dev", "test", "all"]. None behaves like "all".

        Returns:
            Optional[Any]: The example if found, otherwise None.
        """
        if mode is not None and mode not in ["train", "dev", "test", "all"]:
            raise ValueError(f"Invalid value for mode: {mode}. Available choices: ['train', 'dev', 'test', 'all']")
        if mode is None or mode == "all":
            # Search every loaded split in train -> dev -> test order.
            data = []
            for split in (self._train_data, self._dev_data, self._test_data):
                if split is not None:
                    data.extend(split)
        else:
            data = self.get_data_by_mode(mode=mode)
        for example in data:
            if self._get_id(example=example) == example_id:
                return example
        return None

    def get_example_by_index(self, index: int, mode: str = "test") -> Optional[Any]:
        """
        Get an example from the benchmark by its index. Supports Python-style
        negative indices; returns None for any out-of-range index.

        Args:
            index (int): The index of the example to retrieve.
            mode (str): The mode to retrieve the example from, choices: ["train", "dev", "test"].

        Returns:
            Optional[Any]: The example if found, otherwise None.
        """
        data = self.get_data_by_mode(mode=mode)
        # BUG FIX: the old `index < len(data)` check let large negative indices
        # through, raising IndexError instead of returning None.
        return data[index] if -len(data) <= index < len(data) else None

    def _get_data(self, data: List[dict], indices: Optional[List[int]] = None, sample_k: Optional[int] = None, seed: Optional[int] = None) -> List[dict]:
        """
        Retrieves a subset of data based on provided indices or a random sample.

        Args:
            data (List[dict]): The list of data examples.
            indices (List[int], optional): Specific indices of data to retrieve. Defaults to None.
            sample_k (int, optional): The number of random samples to retrieve. Defaults to None.
            seed (int, optional): The seed for random sampling. Defaults to None. If provided,
                the random sampling is deterministic.

        Returns:
            List[dict]: The selected subset of data. If both `indices` and `sample_k` are None,
            it will return the original `data`.
        """
        if indices is None:
            indices = list(range(len(data)))
        if sample_k is not None:
            # Use a private Random instance so seeding does not perturb the
            # global RNG state of the rest of the program. Random(seed) draws
            # the same sequence as random.seed(seed) + module-level calls.
            rng = random if seed is None else random.Random(seed)
            indices = rng.sample(indices, k=min(sample_k, len(indices)))
        return [data[idx] for idx in indices]

    def _get_split_data(self, split_name: str, split: Optional[List[dict]], indices: Optional[List[int]], sample_k: Optional[int], seed: Optional[int]) -> List[dict]:
        """Shared guard + subset logic for the three public split accessors."""
        if split is None:
            return self._warn_missing_split(split_name)
        return self._get_data(split, indices=indices, sample_k=sample_k, seed=seed)

    def get_train_data(self, indices: Optional[List[int]] = None, sample_k: Optional[int] = None, seed: Optional[int] = None) -> List[dict]:
        """Retrieves training data based on specified indices or random sampling."""
        return self._get_split_data("Train", self._train_data, indices, sample_k, seed)

    def get_dev_data(self, indices: Optional[List[int]] = None, sample_k: Optional[int] = None, seed: Optional[int] = None) -> List[dict]:
        """Retrieves development data based on specified indices or random sampling."""
        return self._get_split_data("Dev", self._dev_data, indices, sample_k, seed)

    def get_test_data(self, indices: Optional[List[int]] = None, sample_k: Optional[int] = None, seed: Optional[int] = None) -> List[dict]:
        """Retrieves test data based on specified indices or random sampling."""
        return self._get_split_data("Test", self._test_data, indices, sample_k, seed)
class CodingBenchmark(Benchmark):
    """
    Abstract base class for defining coding benchmarks. This class provides methods
    to execute solution code against unit tests and to compute pass@k metrics.
    """

    def __init__(self, name: str, path: str, mode: str = "all", timeout: int = 60, **kwargs):
        """
        Args:
            name (str): The name of the benchmark.
            path (str): The path to the dataset.
            mode (str): Which type of data to load, choices: ["all", "train", "dev", "test"].
            timeout (int): Per-solution execution time limit in seconds.
            **kwargs: Additional parameters forwarded to `Benchmark`.

        NOTE: the `timeout` parameter shadows the imported `timeout` context
        manager only inside this method; the check_* methods below still refer
        to the module-level import.
        """
        # Status codes returned by the check_solution* methods.
        self.SUCCESS = 0
        self.FAILED = 1
        self.TIMEOUT = 2
        self.timeout = timeout
        super().__init__(name=name, path=path, mode=mode, **kwargs)

    def handle_special_cases(self, task_id: str, solution: str, test: str) -> Tuple[str, str]:
        """
        Hook for subclasses to patch solution/test code for special task ids.
        Returns the (possibly modified) (solution, test) pair.
        """
        # BUG FIX: was annotated `-> bool` but has always returned a tuple.
        return solution, test

    def _check_evaluation_inputs(self, prediction: Any, label: Any) -> Tuple[List[str], List[dict]]:
        """
        Normalize evaluation inputs: wrap a single prediction string / label dict
        into one-element lists and return the (predictions, labels) pair.
        """
        # BUG FIX: return annotation was `-> bool`, and the label message
        # wrongly claimed strings were expected while dicts are checked.
        assert isinstance(prediction, (str, list)), "prediction must be a string or a list of strings, but got {}".format(type(prediction))
        assert isinstance(label, (dict, list)), "label must be a dict or a list of dicts, but got {}".format(type(label))
        prediction = [prediction] if isinstance(prediction, str) else prediction
        label = [label] if isinstance(label, dict) else label
        return prediction, label

    def _build_exec_globals(self) -> dict:
        """
        Build a fresh globals dict for exec'ing solution/test code, pre-loaded
        with the modules and typing names benchmark solutions commonly use.
        (Previously triplicated across the three check_solution* methods.)
        """
        return {
            "math": __import__("math"),
            "hashlib": __import__("hashlib"),
            "re": __import__("re"),
            "List": List,
            "Dict": Dict,
            "Tuple": Tuple,
            "Optional": Optional,
            "Any": Any,
        }

    def check_solution(self, task_id: str, solution: str, test: str, entry_point: Optional[str] = None, use_entrypoint_as_input: bool = True) -> Tuple[int, str]:
        """
        Execute the solution code and check if it passes the unit test.

        Args:
            task_id (str): The task id.
            solution (str): The solution code.
            test (str): The unit test code in HumanEval format (defines `check`).
            entry_point (str): The entry point of the solution code.
            use_entrypoint_as_input (bool): If True, pass the entry-point function
                to `check`; otherwise call `check()` with no arguments.

        Returns:
            Tuple[int, str]: (status, message) where status is
            0 (SUCCESS) / 1 (FAILED) / 2 (TIMEOUT).
        """
        # SECURITY NOTE: exec() runs arbitrary generated code in-process; only
        # use with trusted or sandboxed inputs.
        solution = sanitize(solution, entrypoint=entry_point)
        try:
            global_dict = self._build_exec_globals()
            solution, test = self.handle_special_cases(task_id=task_id, solution=solution, test=test)
            exec(solution, global_dict)
            if entry_point not in global_dict:
                raise ValueError(f"Function {entry_point} not found in the solution code.")
            exec(test, global_dict)
            unit_test_func = global_dict["check"]  # "check" is the conventional test entry name
            # Run the unit test within the timeout.
            with timeout(seconds=self.timeout):
                if use_entrypoint_as_input:
                    unit_test_func(global_dict[entry_point])
                else:
                    unit_test_func()
            result = (self.SUCCESS, "The solution passed the unit test.")
        except TimeoutException:
            result = (self.TIMEOUT, "Execution timed out.")
        except Exception as e:
            error_msg = f"An error occurred: {e}\nSolution:\n{solution}\nTest:\n{test}"
            result = (self.FAILED, error_msg)
        return result

    def check_solution_scicode(self, task_id: str, solution: str, test: str, entry_point: Optional[str] = None, use_entrypoint_as_input: bool = True) -> Tuple[int, str]:
        """
        Execute the solution code and check if it passes the unit test
        (SciCode-style: the test script is exec'd directly and is expected to
        raise on failure).

        Args:
            task_id (str): The task id.
            solution (str): The solution code.
            test (str): The unit test code.
            entry_point (str): The entry point of the solution code.
            use_entrypoint_as_input (bool): Must be True; False is not supported
                for this check style.

        Returns:
            Tuple[int, str]: (status, message) where status is
            0 (SUCCESS) / 1 (FAILED) / 2 (TIMEOUT).
        """
        solution = sanitize(solution, entrypoint=entry_point)
        try:
            global_dict = self._build_exec_globals()
            solution, test = self.handle_special_cases(task_id=task_id, solution=solution, test=test)
            exec(solution, global_dict)
            if entry_point not in global_dict:
                raise ValueError(f"Function {entry_point} not found in the solution code.")
            with timeout(seconds=self.timeout):
                if use_entrypoint_as_input:
                    exec(test, global_dict)
                else:
                    # BUG FIX: this branch referenced an undefined `error_msg`
                    # (NameError), and even if it had set `result` the SUCCESS
                    # assignment below would have overwritten it. Raise
                    # explicitly so the outcome is FAILED with a clear message.
                    raise ValueError("use_entrypoint_as_input=False is not supported by check_solution_scicode.")
            result = (self.SUCCESS, "The solution passed the unit test.")
        except TimeoutException:
            result = (self.TIMEOUT, "Execution timed out.")
        except Exception as e:
            error_msg = f"An error occurred: {e}\nSolution:\n{solution}\nTest:\n{test}"
            result = (self.FAILED, error_msg)
        return result

    def check_solution_plus(self, task_id: str, solution: str, test: str, entry_point: Optional[str] = None, use_entrypoint_as_input: bool = True) -> Tuple[int, str]:
        """
        Execute the solution code and check if it passes the unit test
        (EvalPlus-style: the test script is exec'd directly; `use_entrypoint_as_input`
        is accepted for signature compatibility but not consulted).

        Args:
            task_id (str): The task id.
            solution (str): The solution code.
            test (str): The unit test code.
            entry_point (str): The entry point of the solution code.

        Returns:
            Tuple[int, str]: (status, message) where status is
            0 (SUCCESS) / 1 (FAILED) / 2 (TIMEOUT).
        """
        solution = sanitize(solution, entrypoint=entry_point)
        # Defensive default in case a non-Exception (e.g. KeyboardInterrupt)
        # escapes the try block before `result` is assigned.
        result = (self.FAILED, "The code running is wrong")
        try:
            global_dict = self._build_exec_globals()
            solution, test = self.handle_special_cases(task_id=task_id, solution=solution, test=test)
            exec(solution, global_dict)
            if entry_point not in global_dict:
                raise ValueError(f"Function {entry_point} not found in the solution code.")
            with timeout(seconds=self.timeout):
                exec(test, global_dict)
            result = (self.SUCCESS, "The solution passed the unit test.")
        except TimeoutException:
            result = (self.TIMEOUT, "Execution timed out.")
        except Exception as e:
            error_msg = f"An error occurred: {e}\nSolution:\n{solution}\nTest:\n{test}"
            result = (self.FAILED, error_msg)
        return result

    def compute_pass_at_k(self, results: List[bool], k_list: List[int]) -> Dict[str, float]:
        """
        Compute pass@k for the given per-sample pass/fail results.

        Args:
            results (List[bool]): One entry per generated sample; True if it passed.
            k_list (List[int]): The k values to report; k > len(results) is skipped.

        Returns:
            Dict[str, float]: Mapping "pass@k" -> estimated pass rate.
        """
        pass_at_k = {}
        n = len(results)   # total samples
        c = sum(results)   # correct samples
        for k in k_list:
            if n >= k:
                pass_at_k[f"pass@{k}"] = float(estimate_pass_at_k(np.array([n]), np.array([c]), k)[0])
        return pass_at_k