|
|
import math |
|
|
import multiprocessing |
|
|
import time |
|
|
from typing import Any, List, Union |
|
|
|
|
|
from evalplus.data import get_human_eval_plus |
|
|
from evalplus.eval.utils import ( |
|
|
TimeoutException, |
|
|
create_tempdir, |
|
|
reliability_guard, |
|
|
swallow_io, |
|
|
time_limit, |
|
|
) |
|
|
|
|
|
# Upper bound on how many base inputs are run as untimed warm-up calls
# before each timed execution (see test_solution_runtime).
MAX_WARMUP_LIMIT = 5
# Number of timed executions collected per input; the mean and sample
# standard deviation are computed over these runs.
RUN_REPEAT = 25
|
|
|
|
|
|
|
|
def execute_for_runtime(
    code: str, inputs: List, warmups: List, entry_point: str
) -> Union[str, float]:
    """Execute ``entry_point`` from ``code`` in a child process and time one call.

    The child ``exec``s ``code``, calls the entry point once per element of
    ``warmups`` (untimed), then times a single call with ``inputs``.

    Args:
        code: Python source that defines ``entry_point`` at module level.
        inputs: Positional arguments for the timed call (``fn(*inputs)``).
        warmups: Argument lists for the warm-up calls (``fn(*warmup)`` each).
        entry_point: Name of the function to look up after ``exec``.

    Returns:
        The elapsed wall-clock seconds of the timed call as a ``float``, or
        the string ``"timed out"`` / ``"thrown exception"`` when the run did
        not complete normally. ``"timed out"`` is also returned when the child
        had to be killed or exited without reporting anything.
    """
    time_limit_s = 3  # budget, in seconds, for the single timed call

    def unsafe_execute():
        with create_tempdir():
            import os
            import shutil

            # Keep references to these functions so they can be put back
            # before the temp dir context exits.
            # NOTE(review): presumably reliability_guard() disables them and
            # create_tempdir() needs them for cleanup - confirm against
            # evalplus.eval.utils.
            rmtree = shutil.rmtree
            rmdir = os.rmdir
            chdir = os.chdir

            reliability_guard()

            try:
                # Loading the code is inside the try so that a broken
                # submission is reported instead of silently killing the child.
                exec_globals = {}
                exec(code, exec_globals)
                fn = exec_globals[entry_point]

                # Warm-up calls are not individually time-limited; a hang here
                # is bounded only by the parent's join timeout below.
                for warmup in warmups:
                    with swallow_io():
                        fn(*warmup)

                # perf_counter is monotonic and high-resolution, unlike
                # time.time(), which can jump if the system clock is adjusted.
                start_time = time.perf_counter()
                with swallow_io():
                    with time_limit(time_limit_s):
                        fn(*inputs)
                duration = time.perf_counter() - start_time

                result.append(duration)
            except TimeoutException:
                result.append("timed out")
            except BaseException:
                # Deliberately broad: the executed code is arbitrary and may
                # raise anything, including SystemExit/KeyboardInterrupt.
                result.append("thrown exception")
            finally:
                # Always restore, whatever happened above.
                shutil.rmtree = rmtree
                os.rmdir = rmdir
                os.chdir = chdir

    # The context manager shuts the manager's server process down on exit, so
    # repeated calls do not accumulate helper processes.
    with multiprocessing.Manager() as manager:
        result = manager.list()
        # NOTE(review): the target is a nested function, which is not
        # picklable; this relies on a "fork"-style start method.
        p = multiprocessing.Process(target=unsafe_execute)
        p.start()
        p.join(timeout=time_limit_s + 1)
        if p.is_alive():
            p.kill()
            p.join()  # reap the killed child
        # Copy out of the proxy before the manager goes away.
        outcomes = list(result)

    # An empty list means the child never reported (killed or crashed).
    return outcomes[0] if outcomes else "timed out"
|
|
|
|
|
|
|
|
def test_solution_runtime(
    dataset: str = "humaneval",
    task_id: str = "HumanEval/0",
    impl: str = "canonical",
    inputs: Union[str, List[List[Any]]] = "base_input",
):
    """Measure the runtime of a solution for one task over a set of inputs.

    Each input is executed ``RUN_REPEAT`` times via ``execute_for_runtime``.
    Among all inputs, the mean and sample standard deviation of the input
    whose runs had the smallest standard deviation are reported.

    Args:
        dataset: Dataset name; only names containing ``"humaneval"`` are
            handled.
        task_id: Task identifier to look up in the dataset.
        impl: Solution body appended to the task prompt, or ``"canonical"``
            to use the task's ``canonical_solution``.
        inputs: List of argument lists to time, or ``"base_input"`` to use
            the task's ``base_input``.

    Returns:
        ``[avg_runtime, sd]`` for the most stable input; ``(None, None)`` if
        any execution did not yield a float runtime; ``None`` when ``dataset``
        is not a HumanEval dataset. The initial ``[1000, 1000]`` is returned
        unchanged if no input produced a standard deviation below 1000
        (for example when ``inputs`` is empty).

    Raises:
        AssertionError: If ``task_id`` is not found in the dataset.
    """
    if "humaneval" not in dataset:
        # Only HumanEval-style datasets are handled; others yield None.
        return None

    problem = next(
        (p for p in get_human_eval_plus() if p["task_id"] == task_id), None
    )
    assert problem is not None, f"invalid {task_id = }"

    entry_point = problem["entry_point"]
    solution_body = impl if impl != "canonical" else problem["canonical_solution"]
    impl = problem["prompt"] + solution_body
    if inputs == "base_input":
        inputs = problem["base_input"]

    # [mean, sd] of the most stable input seen so far; 1000 acts as the
    # "nothing selected yet" sentinel.
    results = [1000, 1000]
    for input_list in inputs:
        # Warm up with the first few base inputs that differ from the one
        # being timed (compared by their string form).
        input_key = str(input_list)
        warmups = [
            base_input_list
            for base_input_list in problem["base_input"]
            if str(base_input_list) != input_key
        ][:MAX_WARMUP_LIMIT]

        runtime_list = [
            execute_for_runtime(impl, input_list, warmups, entry_point)
            for _ in range(RUN_REPEAT)
        ]
        # A non-float entry is an error marker string from the executor.
        if not all(isinstance(x, float) for x in runtime_list):
            print(f"{task_id = } incorrect")
            return None, None

        avg_runtime = sum(runtime_list) / len(runtime_list)
        # Sample standard deviation (n - 1 in the denominator).
        sd = math.sqrt(
            sum((runtime - avg_runtime) ** 2 for runtime in runtime_list)
            / (len(runtime_list) - 1)
        )
        if sd < results[1]:
            results[0] = avg_runtime
            results[1] = sd

    return results
|
|
|