|
|
import ast |
|
|
import json |
|
|
import multiprocessing |
|
|
import os |
|
|
import re |
|
|
from concurrent.futures import ProcessPoolExecutor, as_completed |
|
|
from typing import Dict, List |
|
|
|
|
|
import astor |
|
|
import black |
|
|
from tqdm import tqdm |
|
|
|
|
|
from evalplus.data import ( |
|
|
get_human_eval_plus, |
|
|
get_human_eval_plus_hash, |
|
|
get_mbpp_plus, |
|
|
get_mbpp_plus_hash, |
|
|
) |
|
|
from evalplus.eval import PASS |
|
|
from evalplus.evalperf import correctness_check |
|
|
from evalplus.evaluate import get_groundtruth |
|
|
|
|
|
|
|
|
def find_calls(source, functors, skip_main_fn=True):
    """Map each function defined in ``source`` to the target functors it calls.

    Args:
        source: Python source code to analyse.
        functors: Collection of bare function names to look for, e.g.
            ``["print"]``. Only calls made through a plain name are matched;
            attribute calls such as ``obj.print(...)`` are not.
        skip_main_fn: When true, a function named ``main`` is not visited at
            all, so neither it nor anything nested inside it is reported.

    Returns:
        A dict mapping a function name to the set of target functor names
        called inside that function's definition. Functions calling none of
        the targets are omitted. Calls inside a nested function are attributed
        to the nested function only, and calls outside any function are
        ignored. If a name is defined more than once, the sets are merged.

    Raises:
        SyntaxError: If ``source`` cannot be parsed by ``ast.parse``.
    """

    class FunctorVisitor(ast.NodeVisitor):
        def __init__(self, target_functors):
            super().__init__()
            self.target_functors = target_functors
            # Functors seen so far in the function currently being visited.
            self._temp_found_functors = set()
            self.results = []

        def visit_FunctionDef(self, node: ast.FunctionDef):
            if skip_main_fn and node.name == "main":
                return

            # Save the enclosing function's findings so that a nested
            # definition does not clobber them.
            enclosing_found = self._temp_found_functors
            self._temp_found_functors = set()
            self.generic_visit(node)
            if self._temp_found_functors:
                self.results.append((node.name, self._temp_found_functors))
            self._temp_found_functors = enclosing_found

        def visit_Call(self, node):
            if isinstance(node.func, ast.Name) and node.func.id in self.target_functors:
                self._temp_found_functors.add(node.func.id)
            # Keep descending: arguments may themselves contain target calls.
            self.generic_visit(node)

    tree = ast.parse(source)
    visitor = FunctorVisitor(target_functors=functors)
    visitor.visit(tree)

    calls = {}
    for caller, callees in visitor.results:
        calls.setdefault(caller, set()).update(callees)
    return calls
|
|
|
|
|
|
|
|
def void_calls(source, functors, skip_main_fn=True):
    """Replace calls to the given functors with ``None`` and re-render the code.

    Args:
        source: Python source code to rewrite.
        functors: Collection of bare function names whose calls are voided,
            e.g. ``["print"]``. Only calls through a plain name are matched.
        skip_main_fn: When true, a function named ``main`` is returned
            untouched (its body is not visited).

    Returns:
        A ``(code, changed)`` tuple: ``code`` is the rewritten source rendered
        with ``astor.to_source`` and formatted with ``black``; ``changed`` is
        True if at least one call was replaced.

    Raises:
        SyntaxError: If ``source`` cannot be parsed by ``ast.parse``.
    """
    changed = False

    class FunctorTransformer(ast.NodeTransformer):
        def __init__(self, target_functors):
            super().__init__()
            self.target_functors = target_functors

        def visit_FunctionDef(self, node: ast.FunctionDef):
            if skip_main_fn and node.name == "main":
                return node
            return self.generic_visit(node)

        def visit_Call(self, node):
            nonlocal changed
            if isinstance(node.func, ast.Name) and node.func.id in self.target_functors:
                changed = True
                # A Call is an expression, so it must be replaced by another
                # expression node (not an ast.Expr statement) to stay valid
                # wherever the call appeared, e.g. ``x = print(1)``.
                return ast.copy_location(ast.Constant(value=None), node)
            # Descend so that target calls nested in arguments are voided too.
            return self.generic_visit(node)

    tree = ast.parse(source)
    code = astor.to_source(FunctorTransformer(functors).visit(tree))
    fmt_code = black.format_str(code, mode=black.FileMode())
    return (fmt_code, changed)
|
|
|
|
|
|
|
|
def has_print_in_non_main_functions(source) -> bool:
    """Return True if ``find_calls`` reports a ``print`` call in any function not named ``main``."""
    print_callers = find_calls(source, ["print"])
    return any(name != "main" for name in print_callers)
|
|
|
|
|
|
|
|
def gather_solutions(sample_path: str, task_id: str) -> List[str]:
    """Gather all solutions for one task from every model folder.

    The expected layout is ``<sample_path>/<model>/<task_id>/*.py``.

    Args:
        sample_path: Directory containing one sub-directory per model.
        task_id: Name of the task sub-directory to look for under each model.

    Returns:
        The contents of every ``.py`` file found, ordered by model directory
        name and then by file name so the result is reproducible across runs
        and file systems. Entries that are not directories, and models
        without a directory for ``task_id``, are skipped.

    Raises:
        OSError: If ``sample_path`` cannot be listed or a file cannot be read.
    """
    solutions = []
    # os.listdir returns entries in arbitrary order; sort for determinism.
    for model in sorted(os.listdir(sample_path)):
        # A task directory can only exist beneath a model directory, so this
        # single check also filters out stray files in sample_path.
        task_path = os.path.join(sample_path, model, task_id)
        if not os.path.isdir(task_path):
            continue
        for file_name in sorted(os.listdir(task_path)):
            if not file_name.endswith(".py"):
                continue
            # Read as UTF-8 explicitly instead of the platform default.
            with open(os.path.join(task_path, file_name), "r", encoding="utf-8") as f:
                solutions.append(f.read())
    return solutions
|
|
|
|
|
|
|
|
def deduplicate(solutions: List[str]) -> List[str]:
    """Deduplicate solutions by the AST of their comment-stripped source.

    Each solution first has ``#`` comments and simple triple-quoted strings
    removed with regular expressions, and is then parsed. Two solutions are
    duplicates when ``ast.dump`` of their parsed, stripped source is equal.

    Args:
        solutions: Candidate solution sources.

    Returns:
        The *stripped* text of the first occurrence of each distinct AST, in
        input order. Solutions whose stripped text cannot be parsed or dumped
        are dropped.

    NOTE(review): the stripping is purely textual, so a ``#`` or a
    triple-quoted string that is part of real code (e.g. inside a string
    literal) is removed as well; such solutions may fail to parse and be
    dropped.
    """
    seen_asts = set()
    deduplicated = []
    for solution in solutions:
        solution = re.sub(r"#[^\n]*", "", solution)
        solution = re.sub(r'"""[^"]*"""', "", solution)
        solution = re.sub(r"'''[^']*'''", "", solution)
        try:
            ast_string = ast.dump(ast.parse(solution))
        except (SyntaxError, MemoryError, ValueError, RecursionError):
            # ValueError: source with null bytes on older Pythons;
            # RecursionError: pathologically nested code. One unparsable
            # sample must not abort the whole run.
            continue
        if ast_string not in seen_asts:
            seen_asts.add(ast_string)
            deduplicated.append(solution)
    return deduplicated
|
|
|
|
|
|
|
|
def test_solutions(
    dataset: str, solutions: List[str], task: Dict, expected_output: List
) -> List[str]:
    """Test solutions, return functionally correct solutions.

    Every solution is submitted to ``correctness_check`` in a process pool
    sized to half the CPU count (at least one worker).

    Args:
        dataset: Dataset name, forwarded to ``correctness_check``.
        solutions: Candidate solution sources.
        task: Task description, forwarded to ``correctness_check``.
        expected_output: Ground-truth output for the task, forwarded to
            ``correctness_check``.

    Returns:
        The solutions whose check result starts with ``PASS``, in the same
        relative order as in ``solutions`` (not in completion order, which
        varies from run to run).
    """
    if not solutions:
        # Nothing to check; avoid spinning up a process pool.
        return []

    n_workers = max(1, multiprocessing.cpu_count() // 2)
    correct_solution_ids = []

    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                correctness_check, index, solution, dataset, task, expected_output
            )
            for index, solution in enumerate(solutions)
        ]
        for future in as_completed(futures):
            # correctness_check echoes back the index it was given.
            index, result, _ = future.result()
            if result[0] == PASS:
                correct_solution_ids.append(index)

    # Futures complete in arbitrary order; sort to restore input order.
    return [solutions[i] for i in sorted(correct_solution_ids)]
|
|
|
|
|
|
|
|
def script(sample_dir: str, dataset: str = "humaneval", debug_task: str = None):
    """Collect, validate and store correct solutions for every task.

    For each task of the chosen dataset: gather the sampled solutions under
    ``sample_dir``, deduplicate them, keep those that pass the correctness
    check, void their ``print`` calls, and append one JSON line
    (``task_id`` / ``solution``) to ``solutions.jsonl`` in the working
    directory.

    Args:
        sample_dir: Root directory of the sampled solutions.
        dataset: Either ``"humaneval"`` or ``"mbpp"``.
        debug_task: If given, only the task with this id is processed.
    """
    assert dataset in ["humaneval", "mbpp"]
    if dataset == "humaneval":
        problems = get_human_eval_plus(noextreme=True)
        problems_hash = get_human_eval_plus_hash(noextreme=True)
        expected_output = get_groundtruth(problems, problems_hash, [])
    elif dataset == "mbpp":
        from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS

        problems = get_mbpp_plus(noextreme=True)
        problems_hash = get_mbpp_plus_hash(noextreme=True)
        expected_output = get_groundtruth(
            problems, problems_hash, MBPP_OUTPUT_NOT_NONE_TASKS
        )

    # Never populated in this function, so the merge step below is a no-op.
    previous_solutions = None

    for task_id, task in tqdm(problems.items()):
        if debug_task and task_id != debug_task:
            continue

        # Task ids contain "/" while the on-disk folder names use "_".
        task_dir_name = task_id.replace("/", "_")
        candidates = deduplicate(gather_solutions(sample_dir, task_dir_name))
        passing = test_solutions(dataset, candidates, task, expected_output[task_id])

        # Keep only the rewritten source; the "changed" flag is not needed.
        silenced = []
        for passing_solution in passing:
            voided_source, _ = void_calls(passing_solution, ["print"])
            silenced.append(voided_source)

        if previous_solutions and task_id in previous_solutions:
            silenced = deduplicate(silenced + previous_solutions[task_id])

        with open("solutions.jsonl", "a+") as f:
            record = json.dumps({"task_id": task_id, "solution": silenced})
            f.write(record + "\n")
|
|
|
|
|
|
|
|
def main():
    """Command-line entry point: expose ``script`` through python-fire."""
    # Imported lazily so that importing this module does not require fire.
    import fire

    fire.Fire(script)
|
|
|
|
|
|
|
|
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
|
|
|