import ast
import json
import multiprocessing
import os
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List, Optional

import astor
import black
from tqdm import tqdm

from evalplus.data import (
    get_human_eval_plus,
    get_human_eval_plus_hash,
    get_mbpp_plus,
    get_mbpp_plus_hash,
)
from evalplus.eval import PASS
from evalplus.evalperf import correctness_check
from evalplus.evaluate import get_groundtruth


def find_calls(source, functors, skip_main_fn=True):
    """Map each function in *source* to the target functors it calls.

    Args:
        source: Python source code to analyse.
        functors: Names of functions to look for. Only plain-name calls such
            as ``print(...)`` match; attribute calls like ``obj.print(...)``
            do not.
        skip_main_fn: When true, functions named ``main`` (and anything
            nested inside them) are not inspected.

    Returns:
        ``{function_name: set_of_called_functors}`` for every function that
        calls at least one target functor. Calls made at module level are not
        reported. A call inside a nested function is attributed to the
        innermost enclosing function. Functions sharing a name (e.g. methods
        of different classes) have their sets merged.
    """

    class FunctorVisitor(ast.NodeVisitor):
        def __init__(self, target_functors):
            super().__init__()
            self.target_functors = target_functors
            self._current_calls = set()
            self.results = []

        def visit_FunctionDef(self, node):
            if skip_main_fn and node.name == "main":
                return
            # Save the enclosing function's set so a nested def does not
            # clobber the calls collected for its parent.
            outer_calls = self._current_calls
            self._current_calls = set()
            self.generic_visit(node)
            if self._current_calls:
                self.results.append((node.name, self._current_calls))
            self._current_calls = outer_calls

        visit_AsyncFunctionDef = visit_FunctionDef

        def visit_Call(self, node):
            if isinstance(node.func, ast.Name) and node.func.id in self.target_functors:
                self._current_calls.add(node.func.id)
            self.generic_visit(node)

    visitor = FunctorVisitor(target_functors=functors)
    visitor.visit(ast.parse(source))

    found = {}
    for caller, callees in visitor.results:
        found.setdefault(caller, set()).update(callees)
    return found


def void_calls(source, functors, skip_main_fn=True):
    """Replace every call to one of *functors* with the constant ``None``.

    Args:
        source: Python source code to rewrite.
        functors: Names of functions whose plain-name calls are removed.
        skip_main_fn: When true, functions named ``main`` are left untouched.

    Returns:
        ``(formatted_code, changed)``. ``formatted_code`` is the source
        regenerated by astor and formatted by black; comments from *source*
        are therefore not preserved. ``changed`` is True if at least one call
        was replaced.
    """
    changed = False

    class FunctorTransformer(ast.NodeTransformer):
        def __init__(self, target_functors):
            super().__init__()
            self.target_functors = target_functors

        def visit_FunctionDef(self, node):
            if skip_main_fn and node.name == "main":
                return node
            return self.generic_visit(node)

        visit_AsyncFunctionDef = visit_FunctionDef

        def visit_Call(self, node):
            nonlocal changed
            if isinstance(node.func, ast.Name) and node.func.id in self.target_functors:
                changed = True
                # A Call is an expression, so its replacement must be an
                # expression node too; ``ast.Expr`` is a statement and would be
                # invalid in positions such as ``x = print(y)``.
                return ast.copy_location(ast.Constant(value=None), node)
            # Keep descending so calls nested in arguments, e.g.
            # ``f(print(x))``, are voided as well.
            return self.generic_visit(node)

    tree = FunctorTransformer(functors).visit(ast.parse(source))
    code = astor.to_source(tree)
    fmt_code = black.format_str(code, mode=black.FileMode())
    return (fmt_code, changed)


def has_print_in_non_main_functions(source) -> bool:
    """Return True if any function other than ``main`` calls ``print``."""
    return any(caller != "main" for caller in find_calls(source, ["print"]))


def gather_solutions(sample_path: str, task_id: str) -> List[str]:
    """Read every ``.py`` sample for *task_id* from all model folders.

    The expected layout is ``<sample_path>/<model>/<task_id>/*.py``. Entries
    under *sample_path* that are not directories are ignored, as are models
    without a folder for *task_id*. Listings are sorted so the returned order
    is deterministic.
    """
    solutions = []
    for model in sorted(os.listdir(sample_path)):
        task_path = os.path.join(sample_path, model, task_id)
        if not os.path.isdir(task_path):
            continue
        for file in sorted(os.listdir(task_path)):
            if file.endswith(".py"):
                with open(os.path.join(task_path, file), "r", encoding="utf-8") as f:
                    solutions.append(f.read())
    return solutions


def _strip_docstrings(tree: ast.AST) -> ast.AST:
    """Remove module/class/function docstrings from *tree* in place; return it."""
    for node in ast.walk(tree):
        if not isinstance(
            node, (ast.Module, ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef)
        ):
            continue
        body = node.body
        if (
            body
            and isinstance(body[0], ast.Expr)
            and isinstance(body[0].value, ast.Constant)
            and isinstance(body[0].value.value, str)
        ):
            del body[0]
            if not body and not isinstance(node, ast.Module):
                # A def/class body may not be empty.
                body.append(ast.Pass())
    return tree


def deduplicate(solutions: List[str]) -> List[str]:
    """Drop unparseable and structurally duplicate solutions.

    Two solutions are duplicates when their ASTs match after docstrings are
    removed; comments never appear in the AST, so they are ignored as well.
    The returned code is regenerated from that AST by astor. It therefore
    carries no comments or docstrings. String literals containing ``#`` or
    triple quotes are left intact, unlike a textual regex strip.
    The first occurrence of each distinct solution is kept, in input order.
    """
    seen = set()
    deduplicated = []
    for solution in solutions:
        try:
            tree = _strip_docstrings(ast.parse(solution))
            key = ast.dump(tree)
        except (SyntaxError, ValueError, MemoryError, RecursionError):
            # Syntax errors, null bytes, or pathologically deep nesting.
            continue
        if key in seen:
            continue
        try:
            cleaned = astor.to_source(tree)
        except Exception:  # astor raises assorted errors on unsupported syntax
            # Best effort: a sample that cannot be regenerated is skipped.
            continue
        seen.add(key)
        deduplicated.append(cleaned)
    return deduplicated


def test_solutions(
    dataset: str, solutions: List[str], task: Dict, expected_output: List
) -> List[str]:
    """Test solutions in parallel and return the ones that pass, in input order.

    Each solution is passed to ``correctness_check`` in a worker process, and
    kept when the first element of its reported result equals ``PASS``.
    """
    if not solutions:
        return []
    n_workers = max(1, multiprocessing.cpu_count() // 2)
    correct_solution_ids = []
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                correctness_check, index, solution, dataset, task, expected_output
            )
            for index, solution in enumerate(solutions)
        ]
        for future in as_completed(futures):
            index, result, _ = future.result()
            if result[0] == PASS:
                correct_solution_ids.append(index)
    # as_completed yields in completion order; sort for a deterministic output.
    return [solutions[i] for i in sorted(correct_solution_ids)]


def script(
    sample_dir: str,
    dataset: str = "humaneval",
    debug_task: Optional[str] = None,
    output_path: str = "solutions.jsonl",
):
    """Collect, deduplicate, test and clean model solutions for each task.

    For every task of *dataset*, the ``.py`` samples under
    ``<sample_dir>/<model>/<task_id with '/' -> '_'>/`` are processed as
    follows:
    1. Deduplicate the samples.
    2. Check them against the ground-truth outputs.
    3. Replace ``print`` calls with ``None`` in the correct ones.
    4. Append the result as one JSON line to *output_path*.

    Args:
        sample_dir: Root directory of the model samples.
        dataset: ``"humaneval"`` or ``"mbpp"``.
        debug_task: If given, only this task id is processed.
        output_path: JSONL file the results are appended to.

    Raises:
        ValueError: If *dataset* is not supported.
    """
    if dataset == "humaneval":
        problems = get_human_eval_plus(noextreme=True)
        dataset_hash = get_human_eval_plus_hash(noextreme=True)
        expected_output = get_groundtruth(problems, dataset_hash, [])
    elif dataset == "mbpp":
        from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS

        problems = get_mbpp_plus(noextreme=True)
        dataset_hash = get_mbpp_plus_hash(noextreme=True)
        expected_output = get_groundtruth(
            problems,
            dataset_hash,
            MBPP_OUTPUT_NOT_NONE_TASKS,
        )
    else:
        raise ValueError(
            f"Unsupported dataset {dataset!r}; expected 'humaneval' or 'mbpp'"
        )

    # NOTE(review): never populated here, so the merge below is currently
    # inactive; kept as the hook for merging earlier (assumed correct) results.
    previous_solutions = None

    with open(output_path, "a", encoding="utf-8") as f:
        for task_id, task in tqdm(problems.items()):
            if debug_task and task_id != debug_task:
                continue
            solutions = gather_solutions(sample_dir, task_id.replace("/", "_"))
            solutions = deduplicate(solutions)
            correct_solutions = test_solutions(
                dataset, solutions, task, expected_output[task_id]
            )

            # Remove print calls and reformat the correct solutions.
            correct_solutions = [
                void_calls(solution, ["print"])[0] for solution in correct_solutions
            ]

            # Assuming that the previous solutions are correct
            if previous_solutions and task_id in previous_solutions:
                correct_solutions = deduplicate(
                    correct_solutions + previous_solutions[task_id]
                )

            f.write(
                json.dumps({"task_id": task_id, "solution": correct_solutions}) + "\n"
            )
            # Persist each task immediately so an interrupted run keeps progress.
            f.flush()


def main():
    from fire import Fire

    Fire(script)


if __name__ == "__main__":
    main()