# Restore all essential files - code, configs, and MBPP/HumanEval data
# (hjkim00, commit 24c2665)
import ast
import json
import multiprocessing
import os
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List
import astor
import black
from tqdm import tqdm
from evalplus.data import (
get_human_eval_plus,
get_human_eval_plus_hash,
get_mbpp_plus,
get_mbpp_plus_hash,
)
from evalplus.eval import PASS
from evalplus.evalperf import correctness_check
from evalplus.evaluate import get_groundtruth
def find_calls(source, functors, skip_main_fn=True):
    """Map each function defined in *source* to the target functors it calls.

    Args:
        source: Python source code to inspect.
        functors: iterable of callable names to look for (e.g. ``["print"]``).
        skip_main_fn: when True, a function named ``main`` is not inspected.

    Returns:
        Dict mapping function name -> set of functor names called inside it.
        Functions that call none of the functors are omitted.
    """

    class FunctorVisitor(ast.NodeVisitor):
        def __init__(self, target_functors):
            super().__init__()
            self.target_functors = target_functors
            self._temp_found_functors = set()
            self.results = []

        def visit_FunctionDef(self, node: ast.FunctionDef):
            # BUG FIX: the original class defined visit_FunctionDef twice;
            # the second definition shadowed the first, so `results` was never
            # populated and find_calls always returned {}. This merged version
            # both honors skip_main_fn and records what each function calls.
            if skip_main_fn and node.name == "main":
                return
            self.current_function = node.name
            self._temp_found_functors = set()
            self.generic_visit(node)  # continue traversing child nodes
            if self._temp_found_functors:
                self.results.append((self.current_function, self._temp_found_functors))

        def visit_Call(self, node):
            # Only plain-name calls (`print(...)`) are matched, not attributes.
            if isinstance(node.func, ast.Name) and node.func.id in self.target_functors:
                self._temp_found_functors.add(node.func.id)
            self.generic_visit(node)

    tree = ast.parse(source)
    visitor = FunctorVisitor(target_functors=functors)
    visitor.visit(tree)
    return {caller: callee for caller, callee in visitor.results}
def void_calls(source, functors, skip_main_fn=True):
    """Replace calls to *functors* in *source* with ``None`` and reformat.

    Args:
        source: Python source code to transform.
        functors: iterable of callable names whose calls get voided.
        skip_main_fn: when True, a function named ``main`` is left untouched.

    Returns:
        Tuple ``(formatted_source, changed)`` where *changed* is True iff at
        least one call was replaced. Output is black-formatted.
    """
    changed = False

    class FunctorTransformer(ast.NodeTransformer):
        def __init__(self, target_functors):
            super().__init__()
            self.target_functors = target_functors

        def visit_FunctionDef(self, node: ast.FunctionDef):
            if skip_main_fn and node.name == "main":
                return node
            # visit child nodes
            return self.generic_visit(node)

        def visit_Call(self, node):
            if isinstance(node.func, ast.Name) and node.func.id in self.target_functors:
                nonlocal changed
                changed = True
                # BUG FIX: a Call sits in *expression* position, so it must be
                # replaced by another expression. The original returned
                # ast.Expr (a statement node), yielding an invalid tree such
                # as Expr(value=Expr(...)) for a bare `print(x)` statement.
                return ast.copy_location(ast.Constant(value=None), node)
            return node

    tree = ast.parse(source)
    code = astor.to_source(FunctorTransformer(functors).visit(tree))
    fmt_code = black.format_str(code, mode=black.FileMode())
    return (fmt_code, changed)
def has_print_in_non_main_functions(source) -> bool:
    """Return True if any function other than ``main`` calls ``print``."""
    callers = find_calls(source, ["print"])
    return any(name != "main" for name in callers)
def gather_solutions(sample_path: str, task_id: str) -> List[str]:
    """Gather all solutions from the folders.

    Reads every ``.py`` file found under ``sample_path/<model>/<task_id>/``
    and returns their contents as a list of source strings.
    """
    collected = []
    for model_name in os.listdir(sample_path):
        model_dir = os.path.join(sample_path, model_name)
        if not os.path.isdir(model_dir):
            # Skip stray files sitting next to the model folders.
            continue
        task_dir = os.path.join(model_dir, task_id)
        if not os.path.isdir(task_dir):
            continue
        for fname in os.listdir(task_dir):
            if not fname.endswith(".py"):
                continue
            with open(os.path.join(task_dir, fname), "r") as fh:
                collected.append(fh.read())
    return collected
def deduplicate(solutions: List[str]) -> List[str]:
    """Deduplicate solutions by AST equality.

    Comments and (crudely) triple-quoted strings are stripped first, so
    solutions differing only in comments/docstrings collapse into one.
    Unparseable solutions are dropped. Note: the returned strings are the
    stripped versions, not the originals.
    """
    seen_dumps = set()
    unique = []
    strip_patterns = (r"#[^\n]*", r'"""[^"]*"""', r"'''[^']*'''")
    for code in solutions:
        for pattern in strip_patterns:
            code = re.sub(pattern, "", code)
        try:
            dump = ast.dump(ast.parse(code))
        except (SyntaxError, MemoryError):
            # Skip anything that cannot be parsed at all.
            continue
        if dump in seen_dumps:
            continue
        seen_dumps.add(dump)
        unique.append(code)
    return list(unique)
def test_solutions(
    dataset: str, solutions: List[str], task: Dict, expected_output: List
) -> List[str]:
    """Test solutions, return functionally correct solutions.

    Args:
        dataset: "humaneval" or "mbpp" (forwarded to correctness_check).
        solutions: candidate solution source strings.
        task: evalplus problem record for a single task.
        expected_output: ground-truth outputs for the task's inputs.

    Returns:
        The subset of *solutions* whose correctness check passed, in the
        original input order.
    """
    # Half the cores: each correctness check runs in its own process and
    # oversubscription would slow the overall evaluation down.
    n_workers = max(1, multiprocessing.cpu_count() // 2)
    correct_solution_ids = []
    with ProcessPoolExecutor(max_workers=n_workers) as executor:
        futures = [
            executor.submit(
                correctness_check, index, solution, dataset, task, expected_output
            )
            for index, solution in enumerate(solutions)
        ]
        for future in as_completed(futures):
            index, result, _ = future.result()
            if result[0] == PASS:
                correct_solution_ids.append(index)
    # BUG FIX: as_completed yields in completion order, which varies between
    # runs; sort the indices so the output order is deterministic.
    return [solutions[i] for i in sorted(correct_solution_ids)]
def script(sample_dir: str, dataset: str = "humaneval", debug_task: str = None):
    """For each task in the dataset: gather sampled solutions, deduplicate,
    keep the functionally correct ones, void their print calls, and append
    one JSON record per task to ``solutions.jsonl``.

    Args:
        sample_dir: root directory laid out as ``<model>/<task_id>/<file>.py``.
        dataset: "humaneval" or "mbpp".
        debug_task: when set, only this task id is processed.
    """
    assert dataset in ["humaneval", "mbpp"]
    if dataset == "humaneval":
        problems = get_human_eval_plus(noextreme=True)
        dataset_hash = get_human_eval_plus_hash(noextreme=True)
        expected_output = get_groundtruth(problems, dataset_hash, [])
    elif dataset == "mbpp":
        # MBPP tasks whose oracle treats non-None output specially.
        from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS

        problems = get_mbpp_plus(noextreme=True)
        dataset_hash = get_mbpp_plus_hash(noextreme=True)
        expected_output = get_groundtruth(
            problems,
            dataset_hash,
            MBPP_OUTPUT_NOT_NONE_TASKS,
        )
    # NOTE(review): previous_solutions is never reassigned, so the merge
    # branch below is currently dead code — presumably it was meant to be
    # loaded from an earlier solutions.jsonl; confirm intent.
    previous_solutions = None
    for task_id, task in tqdm(problems.items()):
        if debug_task and task_id != debug_task:
            continue
        # On-disk folder names use "_" where task ids use "/" (HumanEval/0).
        solutions = gather_solutions(sample_dir, task_id.replace("/", "_"))
        solutions = deduplicate(solutions)
        correct_solutions = test_solutions(
            dataset, solutions, task, expected_output[task_id]
        )
        # clean solutions to remove print statements and format it
        correct_solutions = [
            void_calls(solution, ["print"])[0] for solution in correct_solutions
        ]
        # Assuming that the previous solutions are correct
        if previous_solutions and task_id in previous_solutions:
            correct_solutions = deduplicate(
                correct_solutions + previous_solutions[task_id]
            )
        # NOTE(review): "a+" appends, so re-running the script duplicates
        # records in solutions.jsonl — confirm this is intended.
        with open("solutions.jsonl", "a+") as f:
            f.write(
                json.dumps({"task_id": task_id, "solution": correct_solutions}) + "\n"
            )
def main():
    """CLI entry point: expose ``script`` through python-fire."""
    # fire is imported here, at call time, rather than at module import.
    from fire import Fire

    Fire(script)


if __name__ == "__main__":
    main()