File size: 6,696 Bytes
24c2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import ast
import json
import multiprocessing
import os
import re
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Dict, List

import astor
import black
from tqdm import tqdm

from evalplus.data import (
    get_human_eval_plus,
    get_human_eval_plus_hash,
    get_mbpp_plus,
    get_mbpp_plus_hash,
)
from evalplus.eval import PASS
from evalplus.evalperf import correctness_check
from evalplus.evaluate import get_groundtruth


def find_calls(source, functors, skip_main_fn=True):
    """Map each function defined in *source* to the target functors it calls.

    Args:
        source: Python source code to parse and inspect.
        functors: iterable of bare function names to look for (e.g. ["print"]).
        skip_main_fn: if True (default), a function literally named "main"
            is not inspected and never appears in the result.

    Returns:
        Dict mapping function name -> set of functor names called inside it.
        Functions that call none of the target functors are omitted.
    """

    class FunctorVisitor(ast.NodeVisitor):
        def __init__(self, target_functors):
            super().__init__()
            self.target_functors = set(target_functors)
            self._temp_found_functors = set()
            self.results = []

        def visit_FunctionDef(self, node: ast.FunctionDef):
            # BUG FIX: the original class defined visit_FunctionDef twice; the
            # second definition shadowed the first, so `results` was never
            # appended to and find_calls always returned {}. The two bodies
            # are merged here: skip "main" when requested, then collect the
            # functors found while walking this function's body.
            if skip_main_fn and node.name == "main":
                return
            # Save/restore the accumulator so nested defs attribute their
            # calls to the innermost enclosing function.
            saved = self._temp_found_functors
            self._temp_found_functors = set()
            self.generic_visit(node)  # continue traversing child nodes
            if self._temp_found_functors:
                self.results.append((node.name, self._temp_found_functors))
            self._temp_found_functors = saved

        def visit_Call(self, node):
            if isinstance(node.func, ast.Name) and node.func.id in self.target_functors:
                self._temp_found_functors.add(node.func.id)
            self.generic_visit(node)

    tree = ast.parse(source)
    visitor = FunctorVisitor(target_functors=functors)
    visitor.visit(tree)
    return dict(visitor.results)


def void_calls(source, functors, skip_main_fn=True):
    """Replace every call to one of *functors* with the constant ``None``.

    Used to strip side-effecting calls (e.g. ``print``) from candidate
    solutions before storing them.

    Args:
        source: Python source code to rewrite.
        functors: iterable of bare function names whose calls are voided.
        skip_main_fn: if True (default), the body of a function named
            "main" is left untouched.

    Returns:
        Tuple ``(formatted_code, changed)`` where *formatted_code* is the
        black-formatted rewritten source and *changed* is True iff at least
        one call was replaced.
    """
    changed = False

    class FunctorTransformer(ast.NodeTransformer):
        def __init__(self, target_functors):
            super().__init__()
            self.target_functors = set(target_functors)

        # visit func def
        def visit_FunctionDef(self, node: ast.FunctionDef):
            if skip_main_fn and node.name == "main":
                return node
            # visit child nodes
            return self.generic_visit(node)

        def visit_Call(self, node):
            if isinstance(node.func, ast.Name) and node.func.id in self.target_functors:
                nonlocal changed
                changed = True
                # BUG FIX: the original returned ast.Expr(...), a *statement*
                # node, in place of an *expression*, yielding an invalid tree
                # (Expr(value=Expr(...))) for statement-level calls. The
                # correct replacement expression is the bare constant None.
                return ast.copy_location(ast.Constant(value=None), node)
            # BUG FIX: recurse into non-matching calls so that target calls
            # nested inside argument lists are also voided.
            return self.generic_visit(node)

    tree = ast.parse(source)
    code = astor.to_source(FunctorTransformer(functors).visit(tree))
    fmt_code = black.format_str(code, mode=black.FileMode())
    return (fmt_code, changed)


def has_print_in_non_main_functions(source) -> bool:
    """Return True iff any function other than ``main`` calls ``print``."""
    callers = find_calls(source, ["print"])
    return any(name != "main" for name in callers)


def gather_solutions(sample_path: str, task_id: str) -> List[str]:
    """Collect the text of every ``.py`` file stored under
    ``<sample_path>/<model>/<task_id>/`` across all model directories.

    Non-directory entries at either level and non-``.py`` files are ignored.
    """
    collected: List[str] = []
    for model_name in os.listdir(sample_path):
        model_dir = os.path.join(sample_path, model_name)
        if not os.path.isdir(model_dir):
            continue
        task_dir = os.path.join(model_dir, task_id)
        if not os.path.isdir(task_dir):
            continue
        for entry in os.listdir(task_dir):
            if not entry.endswith(".py"):
                continue
            with open(os.path.join(task_dir, entry), "r") as fh:
                collected.append(fh.read())
    return collected


def deduplicate(solutions: List[str]) -> List[str]:
    """Deduplicate solutions"""
    asts = set()
    deduplicated = []
    for solution in solutions:
        solution = re.sub(r"#[^\n]*", "", solution)
        solution = re.sub(r'"""[^"]*"""', "", solution)
        solution = re.sub(r"'''[^']*'''", "", solution)
        try:
            ast_string = ast.dump(ast.parse(solution))
        except SyntaxError:
            continue
        except MemoryError:
            continue
        if ast_string not in asts:
            asts.add(ast_string)
            deduplicated.append(solution)
    return list(deduplicated)


def test_solutions(
    dataset: str, solutions: List[str], task: Dict, expected_output: List
) -> List[str]:
    """Run each candidate through ``correctness_check`` in a process pool and
    return only the candidates whose first result component equals ``PASS``.

    Results are collected in completion order, so the returned list's order
    is not guaranteed to match the input order.
    """
    # Use half the cores (at least one) to leave headroom for the checked code.
    worker_count = max(1, multiprocessing.cpu_count() // 2)
    passing_ids: List[int] = []

    with ProcessPoolExecutor(max_workers=worker_count) as pool:
        pending = []
        for idx, candidate in enumerate(solutions):
            pending.append(
                pool.submit(
                    correctness_check, idx, candidate, dataset, task, expected_output
                )
            )
        for done in as_completed(pending):
            idx, outcome, _ = done.result()
            if outcome[0] == PASS:
                passing_ids.append(idx)

    return [solutions[i] for i in passing_ids]


def script(sample_dir: str, dataset: str = "humaneval", debug_task: str = None):
    """Build ``solutions.jsonl``: for every task in *dataset*, gather sampled
    solutions from *sample_dir*, deduplicate them, keep only the functionally
    correct ones, strip their print calls, and append one JSON line per task.

    Args:
        sample_dir: root directory holding ``<model>/<task_id>/*.py`` samples.
        dataset: either "humaneval" or "mbpp".
        debug_task: if given, process only this task id.
    """
    assert dataset in ["humaneval", "mbpp"]

    if dataset == "humaneval":
        problems = get_human_eval_plus(noextreme=True)
        dataset_hash = get_human_eval_plus_hash(noextreme=True)
        expected_output = get_groundtruth(problems, dataset_hash, [])
    elif dataset == "mbpp":
        from evalplus.eval._special_oracle import MBPP_OUTPUT_NOT_NONE_TASKS

        problems = get_mbpp_plus(noextreme=True)
        dataset_hash = get_mbpp_plus_hash(noextreme=True)
        expected_output = get_groundtruth(
            problems, dataset_hash, MBPP_OUTPUT_NOT_NONE_TASKS
        )

    # NOTE(review): never reassigned, so the merge branch below is currently
    # dead code; kept to preserve the original flow.
    previous_solutions = None

    for task_id, task in tqdm(problems.items()):
        if debug_task and task_id != debug_task:
            continue

        candidates = deduplicate(
            gather_solutions(sample_dir, task_id.replace("/", "_"))
        )
        verified = test_solutions(dataset, candidates, task, expected_output[task_id])

        # clean solutions to remove print statements and format it
        verified = [void_calls(code, ["print"])[0] for code in verified]

        # Assuming that the previous solutions are correct
        if previous_solutions and task_id in previous_solutions:
            verified = deduplicate(verified + previous_solutions[task_id])

        with open("solutions.jsonl", "a+") as out:
            out.write(json.dumps({"task_id": task_id, "solution": verified}) + "\n")


def main():
    """CLI entry point: expose ``script`` through the fire CLI generator."""
    import fire

    fire.Fire(script)


# Run the CLI only when executed directly (not on import).
if __name__ == "__main__":
    main()