File size: 5,966 Bytes
24c2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import argparse
import importlib
import inspect
import multiprocessing
import os
import sys
from io import StringIO
from typing import Any, Callable, List, Union

import coverage

from evalplus.data import get_human_eval_plus
from evalplus.data.utils import to_raw
from evalplus.eval.utils import reliability_guard, swallow_io, time_limit


def construct_inputs_sig(inputs: list) -> str:
    """Render ``inputs`` as the comma-separated argument text of a call.

    Top-level strings are wrapped in single quotes after passing through
    ``to_raw`` (presumably escaping them for embedding in source -- see
    ``evalplus.data.utils``). Every other value is rendered with ``str()``,
    which for containers uses ``repr`` of the elements.

    NOTE(review): values whose ``str()`` is not a Python literal (e.g.
    ``float("inf")`` -> ``inf``) produce text that is not valid source.

    Returns an empty string for an empty ``inputs``.
    """
    # isinstance (not ``type(x) == str``) so str subclasses are quoted too.
    return ",".join(
        f"'{to_raw(x)}'" if isinstance(x, str) else f"{x}" for x in inputs
    )


class Capturing(list):
    """Context manager that records what is printed to ``sys.stdout``.

    Used as ``with Capturing() as lines: ...``. On exit the instance (a list)
    is extended with the captured text split into lines, without line
    terminators, and the previous ``sys.stdout`` is put back.
    """

    def __enter__(self):
        # Keep a handle on the real stream so __exit__ can reinstate it.
        self._stdout = sys.stdout
        self._stringio = StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        sys.stdout = self._stdout
        del self._stringio
        self.extend(captured_text.splitlines())


def parse_lcov(outputs: List[str], func: Callable, mode: str = "branch"):
    """Summarise the ``tmp_src`` portion of an lcov report.

    Parameters:
    * outputs: lines of an lcov report. Only a record that begins at a line
      containing "tmp_src" is kept, up to (excluding) its "end_of_record".
    * func: function whose source is annotated in line mode.
    * mode: "branch" (default); any other value selects line mode.

    Returns:
    * branch mode: ``(per, branch, branch_covered)`` -- the covered fraction
      (1.0 when there are no branches), every branch signature
      ``"BR:<lineno>,<blockno>,<branchno>"``, and those whose taken count is
      neither "0" nor "-".
    * line mode: ``func``'s source text with ``  # Not executed`` appended
      to each non-blank, non-``def`` line of ``func`` that the report lists
      with an execution count of "0".
    """
    extracted_outputs, in_record = [], False
    for line in outputs:
        if not in_record and "tmp_src" in line:
            in_record = True
        if in_record and "end_of_record" in line:
            in_record = False
        if in_record:
            extracted_outputs.append(line)

    if mode == "branch":
        branch, branch_covered = [], []
        for line in extracted_outputs:
            if not line.startswith("BRDA"):
                continue
            # BRDA format: BRDA:<lineno>,<blockno>,<branchno>,<taken>
            # ([:4] tolerates any extra trailing fields).
            lineno, blockno, branchno, taken = line[5:].split(",")[:4]
            branch_sig = f"BR:{lineno},{blockno},{branchno}"
            branch.append(branch_sig)
            # Both "0" and "-" are treated as "branch not taken".
            if taken not in ("0", "-"):
                branch_covered.append(branch_sig)
        per = 1.0 if not branch else len(branch_covered) / len(branch)
        return per, branch, branch_covered

    # Line mode: only here do we need func's source.
    src, start_lineno = inspect.getsourcelines(func)
    end_lineno = start_lineno + len(src) - 1
    for line in extracted_outputs:
        if not line.startswith("DA"):
            continue
        # DA format: DA:<lineno>,<exec_count>[,...]
        lineno_text, exec_count = line[3:].split(",")[:2]
        lineno = int(lineno_text)
        if exec_count != "0" or not start_lineno <= lineno <= end_lineno:
            continue
        idx = lineno - start_lineno
        code_line = src[idx]
        stripped = code_line.strip()
        # Skip blank lines and ``def`` headers; presumably the defs run when
        # the module is loaded, before measurement starts, so they always
        # show a count of 0. Match the keyword at the start of the statement
        # rather than anywhere in the line (``default = 1`` is not a def).
        if stripped and not stripped.startswith(("def ", "async def ")):
            # rstrip instead of [:-1]: the last source line may lack "\n".
            src[idx] = code_line.rstrip("\n") + "  # Not executed\n"
    return "".join(src)


def _safety_test(code: str, inputs: List[List[Any]], entry_point: str):
    """Child-process probe: run ``code`` plus one call per input, guarded.

    Any exception (including a ``time_limit`` timeout or a ``SystemExit``
    raised by the code) makes the process exit with status 1; returning
    normally yields exit status 0. Defined at module level so that
    ``multiprocessing`` can pickle it under the "spawn" start method.
    """
    if not code.endswith("\n"):
        # Otherwise the first appended call is glued onto the last line.
        code += "\n"
    code += "".join(
        f"{entry_point}({construct_inputs_sig(input_list)})\n"
        for input_list in inputs
    )
    reliability_guard()
    try:
        with swallow_io():
            with time_limit(1):
                exec(code, {})
    except BaseException:
        sys.exit(1)


def test_code_coverage(
    code: str, inputs: List[List[Any]], entry_point: str, mode="branch"
):
    """Measure coverage of ``entry_point`` in ``code`` over ``inputs``.

    First runs every call in a separate process (``_safety_test``); if that
    process does not exit with status 0, prints a refusal and returns None.
    Otherwise writes ``code`` to ``tmp_src.py`` in the current directory,
    executes it, calls ``entry_point(*input_list)`` for each input under
    ``coverage`` with branch measurement, and returns
    ``parse_lcov(<lcov report lines>, func, mode)``. ``tmp_src.py`` is
    removed afterwards, even on error.
    """
    # Upper bound on the probe process. The code itself is limited to 1 s by
    # time_limit; this guards against a child the alarm cannot interrupt.
    join_timeout = 10

    p = multiprocessing.Process(
        target=_safety_test, args=(code, inputs, entry_point)
    )
    p.start()
    p.join(timeout=join_timeout)
    if p.is_alive():
        p.terminate()
        p.join(1)
        if p.is_alive():
            p.kill()
            p.join()
    # A killed/terminated child has a non-zero (or negative) exit code.
    safe = p.exitcode == 0
    if not safe:
        print("Potentially dangerous code, refuse coverage test.")
        return None

    src_path = os.path.abspath("tmp_src.py")
    with open(src_path, "w") as f:
        f.write(code)
    try:
        # Compile the source text directly rather than import + reload:
        # reload may reuse a stale __pycache__ entry (pyc validation compares
        # only whole-second mtime and size), and `import tmp_src` needs the
        # CWD on sys.path. The file on disk is still what coverage reports
        # on and what inspect reads in parse_lcov.
        namespace = {"__name__": "tmp_src", "__file__": src_path}
        exec(compile(code, src_path, "exec"), namespace)
        func = namespace.get(entry_point)
        assert func is not None, f"{entry_point = } not exist"

        cov = coverage.Coverage(branch=True)
        cov.start()
        try:
            with swallow_io():
                for input_list in inputs:
                    func(*input_list)
        finally:
            cov.stop()
        with Capturing() as outputs:
            cov.lcov_report(outfile="-")

        return parse_lcov(outputs, func, mode)
    finally:
        os.remove(src_path)


def test_solution_coverage(
    dataset: str = "HumanEvalPlus",
    task_id: str = "HumanEval/0",
    impl: str = "canonical",
    inputs: Union[str, List[List[Any]]] = "base_input",
    mode: str = "branch",
):
    """
    Measure coverage of one HumanEval(+) solution.

    Parameters:
    * dataset: any name containing "HumanEval" (e.g. "HumanEval",
      "HumanEvalPlus"); other names raise NotImplementedError
    * task_id: problem id within the dataset, e.g. "HumanEval/0"
    * impl: "canonical" to use the problem's canonical solution, otherwise
      completion source that is appended to the problem prompt
    * inputs: "base_input" to use the problem's base inputs, or an explicit
      list of argument lists
    * mode: "branch" or "line"

    Returns None if the code is refused by the safety probe; otherwise
    ``(ratio, branches, covered_branches)`` in "branch" mode, or the
    function source annotated with "# Not executed" markers in "line" mode.
    """
    if "HumanEval" not in dataset:
        raise NotImplementedError

    # Stop at the first matching problem instead of scanning all of them.
    problem = next(
        (p for p in get_human_eval_plus() if p["task_id"] == task_id), None
    )
    assert problem is not None, f"invalid {task_id = }"
    entry_point = problem["entry_point"]
    code = problem["prompt"] + (
        impl if impl != "canonical" else problem["canonical_solution"]
    )
    if inputs == "base_input":
        inputs = problem["base_input"]

    return test_code_coverage(code, inputs, entry_point, mode)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode", type=str, default="branch", choices=["line", "branch"]
    )
    args = parser.parse_args()

    # Sweep HumanEval problems 0..163 and report canonical solutions whose
    # base inputs do not reach full coverage.
    for i in range(0, 164):
        task_id = f"HumanEval/{i}"
        result = test_solution_coverage(
            dataset="HumanEval", task_id=task_id, mode=args.mode
        )
        if result is None:
            # Refused by the safety probe (already reported); nothing to show.
            continue
        if args.mode == "branch":
            # Branch mode yields a 3-tuple: (ratio, all, covered).
            per, branch, branch_covered = result
            if per != 1.0:
                print(i, per, len(branch_covered), len(branch))
        elif "Not executed" in result:
            print(f"{task_id = }")
            print(result)