Spaces:
Runtime error
Runtime error
| import os | |
| from evalplus.sanitize import sanitize | |
| from typing import Any, Dict, List, Optional, Tuple, Union | |
| import itertools | |
| import multiprocessing | |
| import time | |
| from multiprocessing import Array, Value | |
| from typing import Any, Dict, List, Tuple, Union | |
| import numpy as np | |
| import pickle | |
| from evalplus.data.utils import CACHE_DIR | |
| from evalplus.eval import * | |
| from evalplus.gen.util import trusted_exec | |
| from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS, _poly | |
| from evalplus.eval.utils import TimeoutException | |
| from evalplus.eval.utils import ( | |
| create_tempdir, | |
| reliability_guard, | |
| swallow_io, | |
| time_limit, | |
| ) | |
| import re | |
| import resource | |
| import traceback | |
# Human-readable status strings returned to callers of the checker.
SUCCESS = "success"
FAILED = "failed"
TIMEOUT = "timed out"
# Compact integer codes for the same statuses; ints are what the worker
# process stores into the multiprocessing.Value shared slot.
_SUCCESS = 0
_FAILED = 1
_TIMEOUT = 2
_UNKNOWN = 3
# Integer code -> status string. _UNKNOWN maps to None so callers can
# detect "the worker never reported a status" (treated as a timeout).
_mapping = {_SUCCESS: SUCCESS, _FAILED: FAILED, _TIMEOUT: TIMEOUT, _UNKNOWN: None}
class MyCustomException(BaseException):
    """Error wrapper raised by the sandboxed executor.

    Deliberately derives from BaseException (not Exception) so it is not
    swallowed by the broad ``except BaseException`` handlers that guard
    the untrusted ``exec`` — it is caught by its own dedicated handler,
    which reports ``.message`` back through the feedback buffer.
    """

    def __init__(self, message):
        # BUGFIX: also forward to BaseException.__init__ so that str(e)
        # and e.args carry the message (previously str(e) was empty and
        # only .message was populated).
        super().__init__(message)
        self.message = message
def get_groundtruth(problems, hashcode, tasks_only_output_not_none):
    """Compute (or load from cache) the expected outputs for every problem.

    Results are cached under ``CACHE_DIR`` keyed by ``hashcode``: when the
    cache file exists it is unpickled and returned directly; otherwise the
    canonical solution of each problem is executed via ``trusted_exec`` on
    both the "base" and "plus" inputs, and the results are pickled for
    the next run.

    Args:
        problems: mapping task_id -> problem dict; each problem needs
            "prompt", "canonical_solution", "entry_point", "base_input"
            and "plus_input".
        hashcode: cache key identifying this dataset snapshot.
        tasks_only_output_not_none: collection of entry-point names for
            which the oracle records only whether the output is not None.

    Returns:
        dict task_id -> {"base", "base_time", "plus", "plus_time"}.
    """
    cache_file = os.path.join(CACHE_DIR, f"{hashcode}.pkl")
    if os.path.exists(cache_file):
        # BUGFIX: message previously read "Load from ground-truth from ..."
        # (doubled "from"); reworded for clarity.
        print(f"Loading ground-truth from {cache_file}")
        with open(cache_file, "rb") as f:
            return pickle.load(f)

    os.makedirs(CACHE_DIR, exist_ok=True)
    print("Computing expected output...")
    tbegin = time.time()
    expected_output = {}
    for task_id, problem in problems.items():
        oracle = {}
        # Run the canonical solution on the base inputs; record_time=True
        # makes trusted_exec also return per-run timing used later as the
        # reference time limit.
        oracle["base"], oracle["base_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["base_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )
        # Same for the extended ("plus") inputs.
        oracle["plus"], oracle["plus_time"] = trusted_exec(
            problem["prompt"] + problem["canonical_solution"],
            problem["plus_input"],
            problem["entry_point"],
            record_time=True,
            output_not_none=problem["entry_point"] in tasks_only_output_not_none,
        )
        expected_output[task_id] = oracle
    print(f"Expected outputs computed in {time.time() - tbegin:.2f}s")

    with open(cache_file, "wb") as f:
        pickle.dump(expected_output, f)

    return expected_output
def remove_unindented_lines(code, protect_before, execeptions, trim_tails):
    """Drop stray top-level (unindented) lines from generated code.

    Scanning stops being "protected" once a line starting with
    ``protect_before`` is seen; before and after that point, any
    non-blank line with zero indentation is removed unless it starts
    with one of ``execeptions``.  When a line starts with any of
    ``trim_tails``, that line and everything after it is removed.

    Args:
        code: source text to clean.
        protect_before: prefix marking the definition to keep.
        execeptions: prefixes of top-level lines to keep (e.g. imports).
            (Parameter name kept as-is for caller compatibility.)
        trim_tails: prefixes that cut off the remainder of the file.

    Returns:
        The cleaned code joined with "\n" (no trailing newline).
    """
    lines = code.splitlines()
    # BUGFIX(perf): was a list probed with `i not in cut_idx` — an
    # accidental O(n^2); a set gives O(1) membership, same behavior.
    cut_idx = set()
    cut_enabled = False
    for i, line in enumerate(lines):
        if not cut_enabled and line.startswith(protect_before):
            cut_enabled = True
            continue
        if line.strip() == "":
            continue
        if any(line.startswith(e) for e in execeptions):
            continue

        lspace = len(line) - len(line.lstrip())
        if lspace == 0:
            cut_idx.add(i)

        # Blank lines were skipped above, so the original's rstrip()
        # before startswith() was a no-op and is dropped.
        if any(line.startswith(t) for t in trim_tails):
            # cut off everything behind
            cut_idx.update(range(i, len(lines)))
            break

    return "\n".join(line for i, line in enumerate(lines) if i not in cut_idx)
def to_four_space_indents(old_code):
    """Promote 3-space indentation to 4 spaces, line by line.

    Any line indented with exactly three spaces receives one extra
    leading space; every other line passes through untouched.  Each
    input line is emitted with a trailing newline, so non-empty input
    always yields output ending in "\n".
    """
    fixed = []
    for line in old_code.splitlines():
        indent = len(line) - len(line.lstrip())
        if indent == 3:
            line = " " + line
        fixed.append(line + "\n")
    return "".join(fixed)
def sanitize_solution(solution, eofs):
    """Run evalplus' sanitizer over a solution dict, in place.

    Strips the raw solution text, sanitizes it against the task's entry
    point, logs the task id whenever the sanitizer changed anything, and
    stores the cleaned code back into ``solution["solution"]``.

    Returns:
        The same ``solution`` dict, with its "solution" field updated.
    """
    task_id = solution["task_id"]
    entry_point = solution["entry_point"]
    raw_code = solution["solution"].strip()
    cleaned = sanitize(
        old_code=raw_code,
        entry_point=entry_point,
        rm_prefix_lines=None,
        eofs=eofs,
    ).strip()
    # Only log when the sanitizer actually modified the code.
    if cleaned != raw_code:
        print("Sanitized: " + task_id)
    solution["solution"] = cleaned
    return solution
################################ eval funcs ############################
# Result of checking one solution:
#   1st item: the status string (SUCCESS / FAILED / TIMEOUT)
#   2nd item (optional): the detailed pass/fail boolean for each input
Result = Tuple[str, List[bool]]
def is_floats(x) -> bool:
    """Return True when x is float-valued.

    Accepts a plain float, a list/tuple whose elements are all floats,
    or a numpy array of dtype float32/float64; anything else is False.
    """
    if isinstance(x, float):
        return True
    if isinstance(x, (list, tuple)):
        return all(isinstance(item, float) for item in x)
    if isinstance(x, np.ndarray):
        # dtype equality against np.float64/np.float32 matches the
        # original's chained `==` comparisons.
        return x.dtype in (np.float64, np.float32)
    return False
def unsafe_execute(
    dataset: str,
    entry_point: str,
    code: str,
    inputs,
    expected: List,
    time_limits,
    atol,
    fast_check,
    stat: Value,
    details: Array,
    progress: Value,
    feedback: Value,
    feedback_size: int,
):
    """Execute untrusted ``code`` and grade it against ``expected`` outputs.

    Designed to run in a separate process: all results are written into
    the multiprocessing shared objects (``stat``, ``details``,
    ``progress``, ``feedback``) rather than returned.  Execution happens
    inside a temp dir with ``reliability_guard`` applied and stdio
    swallowed.

    Args:
        dataset: selects special oracles for "mbpp" / "humaneval".
        entry_point: function name looked up in the executed code.
        code: untrusted solution source, run with ``exec``.
        inputs: argument tuples, one per test case.
        expected: expected output per test case.
        time_limits: per-case wall-clock limits in seconds.
        atol: absolute tolerance for float comparison (0 means exact,
            but is bumped to 1e-6 when the expected value is float).
        fast_check: unused here; kept for interface compatibility.
        stat/details/progress/feedback: shared-memory result slots.
        feedback_size: capacity of the ``feedback`` buffer in bytes.
    """
    with create_tempdir():
        # These system calls are needed when cleaning up tempdir.
        import os
        import shutil

        rmtree = shutil.rmtree
        rmdir = os.rmdir
        chdir = os.chdir
        maximum_memory_bytes = None
        reliability_guard(maximum_memory_bytes=maximum_memory_bytes)
        exec_globals = {}
        try:
            with swallow_io():
                exec(code, exec_globals)
                try:
                    fn = exec_globals[entry_point]
                except KeyError as e:
                    # BUGFIX: the original did `raise f"..."`, which raises a
                    # plain string — a TypeError in Python 3 ("exceptions must
                    # derive from BaseException") — so the intended message
                    # was replaced by an unrelated traceback.  Raise our
                    # exception type so the dedicated handler below reports it.
                    raise MyCustomException(
                        f"Please rename your function to {entry_point} as there is no function named {entry_point}."
                    )
                except BaseException as e:
                    raise MyCustomException("An error occurred.")

                for i, inp in enumerate(inputs):
                    with time_limit(time_limits[i]):
                        out = fn(*inp)
                    exp = expected[i]
                    exact_match = out == exp

                    # ================================================ #
                    # ============== special oracles ================= #
                    if dataset == "mbpp":
                        if "are_equivalent" == entry_point:  # Mbpp/164 special oracle
                            exact_match = exact_match or True
                        elif "sum_div" == entry_point:  # Mbpp/295 special oracle
                            exact_match = exact_match or out == 0
                        elif entry_point in MBPP_OUTPUT_NOT_NONE_TASKS:
                            # For these tasks only "is the output non-None"
                            # matters, unless the function returned a bool.
                            if isinstance(out, bool):
                                exact_match = out == exp
                            else:
                                exact_match = exp == (out is not None)
                    if dataset == "humaneval":
                        if "find_zero" == entry_point:
                            # The returned root must make the polynomial ~0.
                            assert _poly(*out, inp) <= atol, f"The results aren't as expected.\nInput: {inp}\nExpected Output: {exp}\nActual Output: {out}"
                    # ============== special oracles ================= #

                    if atol == 0 and is_floats(exp):
                        atol = 1e-6  # enforce atol for float comparison
                    if not exact_match and atol != 0:
                        # Tolerant float comparison; re-raise as an
                        # AssertionError carrying the standard message.
                        try:
                            np.testing.assert_allclose(out, exp, atol=atol)
                        except BaseException as e:
                            raise AssertionError(f"The results aren't as expected.\nInput: {inp}\nExpected Output: {exp}\nActual Output: {out}")
                    else:
                        assert exact_match, f"The results aren't as expected.\nInput: {inp}\nExpected Output: {exp}\nActual Output: {out}"

                    details[i] = True
                    progress.value += 1
            stat.value = _SUCCESS
            padding = feedback_size - len(SUCCESS)
            feedback.value = (SUCCESS + " " * padding).encode('utf-8')
        except TimeoutException as e:
            # A per-input time limit fired; reported as FAILED with an
            # explanatory feedback string (matching original behavior).
            stat.value = _FAILED
            error_str = "Execution timed out."
            padding = max(0, feedback_size - len(error_str))
            feedback.value = (error_str + " " * padding).encode('utf-8')
        except AssertionError as e:
            # A grading assertion failed: surface the mismatch message.
            stat.value = _FAILED
            error_str = str(e)[:feedback_size]
            padding = max(0, feedback_size - len(error_str))
            feedback.value = (error_str + " " * padding).encode('utf-8')
        except MyCustomException as e:
            stat.value = _FAILED
            error_str = e.message[:feedback_size]
            padding = max(0, feedback_size - len(error_str))
            feedback.value = (error_str + " " * padding).encode('utf-8')
        except BaseException as e:
            # Anything else: distill the traceback down to the portion that
            # concerns the exec'ed (<string>) code where possible.
            stat.value = _FAILED
            error_traceback = traceback.format_exc()
            match = re.search(r'(File "<string>".*)', error_traceback, re.DOTALL)
            if match:
                error_traceback = match.group(1)
            elif "assert _poly" in error_traceback:
                if "TypeError: _poly() argument after *" in error_traceback:
                    error_traceback = "TypeError: Invalid output type, output must be an iterable."
                else:
                    delimiter = r'f"Input: \{inp\}\\nExpected Output: \{exp\}\\nActual Output: \{out\}"'
                    error_traceback = re.split(delimiter, error_traceback)[-1]
            error_str = str(error_traceback)[:feedback_size]
            padding = max(0, feedback_size - len(error_str))
            feedback.value = (error_str + " " * padding).encode('utf-8')
        # Needed for cleaning up.
        shutil.rmtree = rmtree
        os.rmdir = rmdir
        os.chdir = chdir
def untrusted_check(
    dataset: str,
    code: str,
    inputs: List[Any],
    entry_point: str,
    expected,
    atol,
    ref_time: List[float],
    fast_check: bool = False,
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 2.0,
) -> Tuple[str, List, str]:
    """Run untrusted ``code`` in a subprocess and grade it.

    Spawns a worker running ``unsafe_execute`` that reports back through
    shared memory; the worker is terminated (then killed) if it outlives
    the accumulated per-input time limits.

    Returns:
        (stat, details, feedback):
            stat: SUCCESS / FAILED / TIMEOUT string.
            details: per-input pass flags, truncated to the inputs that
                actually ran.
            feedback: human-readable message (up to 500 chars).

    BUGFIX: the return annotation previously claimed
    ``Tuple[str, np.ndarray]`` although three values are returned; the
    annotation now matches the actual return.
    """
    # Per-input limit: scaled ground-truth time, floored at min_time_limit.
    time_limits = [max(min_time_limit, gt_time_limit_factor * t) for t in ref_time]
    timeout = sum(time_limits) + 1
    if not fast_check:
        timeout += 1  # extra time for data collection

    # Shared memory objects written by the worker process.
    progress = Value("i", 0)
    stat = Value("i", _UNKNOWN)
    details = Array("b", [False for _ in range(len(inputs))])
    feedback_size = 500
    feedback = Array('c', b'\0' * feedback_size)

    p = multiprocessing.Process(
        target=unsafe_execute,
        args=(
            dataset,
            entry_point,
            code,
            inputs,
            expected,
            time_limits,
            atol,
            fast_check,
            stat,
            details,
            progress,
            feedback,
            feedback_size,
        ),
    )
    p.start()
    p.join(timeout=timeout + 1)
    if p.is_alive():
        # Graceful terminate first; hard-kill if it is still alive.
        p.terminate()
        time.sleep(0.1)
    if p.is_alive():
        p.kill()
        time.sleep(0.1)

    stat = _mapping[stat.value]
    details = details[: progress.value]
    feedback = feedback.value.decode("utf-8").strip()
    # A missing entry point can never pass; give a targeted message.
    if entry_point not in code:
        feedback = f"Please rename your function to {entry_point} as there is no function named {entry_point}."

    if not stat:
        # _UNKNOWN maps to None: the worker never reported, so the
        # process must have been killed -> treat as a timeout.
        stat = TIMEOUT

    if stat == SUCCESS:
        # SUCCESS only counts if every input ran and passed.
        if len(details) != len(inputs) or not all(details):
            stat = FAILED

    return stat, details, feedback
def check_correctness(
    dataset: str,
    completion_id: int,
    problem: Dict[str, Any],
    solution: str,
    expected_output: Dict[str, List],
    version="base",
    fast_check=False,
    identifier=None,
    min_time_limit: float = 0.1,
    gt_time_limit_factor: float = 2.0,
) -> Dict[str, Union[int, Optional[Result]]]:
    """Grade one solution against the "base" tests (and "plus" on request).

    Always runs ``untrusted_check`` on the problem's base inputs; when
    ``version == "plus"`` it additionally runs the plus inputs.

    Returns:
        A dict carrying ``completion_id``, ``task_id``, ``_identifier``
        and a ``"base"`` entry (plus a ``"plus"`` entry when requested),
        each holding the (stat, details, feedback) triple produced by
        ``untrusted_check``.
    """
    ret = {
        "completion_id": completion_id,
        "task_id": problem["task_id"],
        "_identifier": identifier,
        "base": untrusted_check(
            dataset,
            solution,
            problem["base_input"],
            problem["entry_point"],
            expected=expected_output["base"],
            atol=problem["atol"],
            ref_time=expected_output["base_time"],
            fast_check=fast_check,
            min_time_limit=min_time_limit,
            gt_time_limit_factor=gt_time_limit_factor,
        ),
    }
    if version == "plus":
        ret["plus"] = untrusted_check(
            dataset,
            solution,
            problem["plus_input"],
            problem["entry_point"],
            expected=expected_output["plus"],
            atol=problem["atol"],
            ref_time=expected_output["plus_time"],
            fast_check=fast_check,
            min_time_limit=min_time_limit,
            gt_time_limit_factor=gt_time_limit_factor,
        )
    return ret