"""Measure branch / line coverage of HumanEval(+) solutions on given inputs.

The solution source is written to ``tmp_src.py`` in the current working
directory, imported as module ``tmp_src``, executed under ``coverage`` and the
resulting LCOV report is parsed.
"""

import argparse
import importlib
import inspect
import multiprocessing
import os
import sys
from io import StringIO
from typing import Any, Callable, List, Optional, Tuple, Union

import coverage

from evalplus.data import get_human_eval_plus
from evalplus.data.utils import to_raw
from evalplus.eval.utils import reliability_guard, swallow_io, time_limit

# Name of the throw-away module (and its file) holding the code under test.
_TMP_MODULE = "tmp_src"
_TMP_SOURCE_FILE = "tmp_src.py"

# Hard upper bound (seconds) for the sandboxed dry run.  The child already
# wraps execution in ``time_limit(1)``; this outer bound only matters when the
# child fails to stop on its own.
_SAFETY_JOIN_TIMEOUT = 10

# Number of tasks iterated by the command-line entry point.
_NUM_HUMANEVAL_TASKS = 164


def construct_inputs_sig(inputs: list) -> str:
    """Render ``inputs`` as a comma-separated argument list for a call string.

    ``str`` values are passed through ``to_raw`` and wrapped in single quotes;
    every other value is formatted with ``str()``.  An empty list yields ``""``.
    """
    return ",".join(
        f"'{to_raw(x)}'" if isinstance(x, str) else f"{x}" for x in inputs
    )


class Capturing(list):
    """Context manager that collects everything printed to ``sys.stdout``.

    On exit the captured text is split into lines and appended to this list.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self

    def __exit__(self, *args):
        # Restore stdout first so it is never left redirected.
        sys.stdout = self._stdout
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio


def parse_lcov(
    outputs: List[str], func: Callable, mode: str = "branch"
) -> Union[Tuple[float, List[str], List[str]], str]:
    """Extract coverage information for ``tmp_src`` from an LCOV report.

    Parameters:
    * outputs: the LCOV report, one record line per list item
    * func: the function under test (used in line mode to locate its source)
    * mode: ``"branch"`` for branch coverage, anything else for line coverage

    Returns:
    * branch mode: ``(ratio, all_branches, covered_branches)`` where branches
      are ``"BR:<line>,<block>,<branch>"`` signatures and ``ratio`` is 1.0 when
      the record lists no branches
    * line mode: the source of ``func`` with `` # Not executed`` appended to
      every non-blank, non-``def`` line whose execution count is 0
    """
    # Keep only the record mentioning "tmp_src": from the first line containing
    # it up to (excluding) that record's "end_of_record" line.
    in_record, extracted_outputs = False, []
    for line in outputs:
        if not in_record and "tmp_src" in line:
            in_record = True
        if in_record and "end_of_record" in line:
            in_record = False
        if in_record:
            extracted_outputs.append(line)

    if mode == "branch":
        branch, branch_covered = [], []
        for line in extracted_outputs:
            if not line.startswith("BRDA:"):
                continue
            # BRDA format: BRDA:<line>,<block>,<branch>,<taken>
            # Split from both ends so a comma inside the branch field cannot
            # break the unpacking.
            lineno, blockno, rest = line[5:].split(",", 2)
            branchno, taken = rest.rsplit(",", 1)
            branch_sig = f"BR:{lineno},{blockno},{branchno}"
            branch.append(branch_sig)
            if taken not in ["0", "-"]:
                branch_covered.append(branch_sig)
        per = 1.0 if not branch else len(branch_covered) / len(branch)
        return per, branch, branch_covered

    # Line mode: only lines inside the source span of ``func`` are annotated.
    src, start_lineno = inspect.getsourcelines(func)
    end_lineno = start_lineno + len(src) - 1

    not_covered_lines = []
    for line in extracted_outputs:
        if line.startswith("DA:"):
            # DA format: DA:<line>,<execution count>[,...]
            lineno_str, exec_count = line[3:].split(",")[:2]
            lineno = int(lineno_str)
            if start_lineno <= lineno <= end_lineno and exec_count == "0":
                not_covered_lines.append(lineno)

    for lineno in not_covered_lines:
        idx = lineno - start_lineno
        line = src[idx]
        if line.strip() != "" and "def" not in line:
            # rstrip instead of slicing: the last source line may lack "\n".
            src[idx] = line.rstrip("\n") + " # Not executed\n"
    return "".join(src)


def _safety_test(code: str, inputs: List[List[Any]], entry_point: str) -> None:
    """Dry-run ``code`` on ``inputs`` inside a child process.

    Exits the process with status 1 if anything goes wrong.  Defined at module
    level so it can be used as a ``multiprocessing`` target under any start
    method (nested functions cannot be pickled for ``spawn``).
    """
    calls = "".join(
        f"{entry_point}({construct_inputs_sig(input_list)})\n"
        for input_list in inputs
    )
    reliability_guard()
    try:
        with swallow_io():
            with time_limit(1):
                # NOTE(review): exec of externally supplied code; this relies
                # on reliability_guard/time_limit for containment.
                # The extra "\n" keeps the first call on its own line even if
                # ``code`` does not end with a newline.
                exec(code + "\n" + calls, {})
    except BaseException:
        # Deliberately broad: any failure marks the code as unsafe.
        sys.exit(1)


def test_code_coverage(
    code: str, inputs: List[List[Any]], entry_point: str, mode: str = "branch"
):
    """Run ``entry_point`` from ``code`` on ``inputs`` and report its coverage.

    The code is first dry-run in a separate process; if that run fails or does
    not finish, a message is printed and ``None`` is returned.  Otherwise the
    code is written to ``tmp_src.py`` (current working directory, which must be
    importable), executed under ``coverage`` and the LCOV report is handed to
    ``parse_lcov``.

    Returns whatever ``parse_lcov`` returns for ``mode``, or ``None`` when the
    code was refused.

    Raises ``AssertionError`` if ``entry_point`` is not defined by ``code``.
    """
    p = multiprocessing.Process(
        target=_safety_test, args=(code, inputs, entry_point)
    )
    p.start()
    p.join(_SAFETY_JOIN_TIMEOUT)
    if p.is_alive():
        # The child ignored its own time limit; stop it for good.
        p.kill()
        p.join()
    safe = p.exitcode == 0
    if not safe:
        print("Potentially dangerous code, refuse coverage test.")
        return None

    try:
        with open(_TMP_SOURCE_FILE, "w") as f:
            f.write(code)

        # The file was created after interpreter start-up; make sure the
        # import system notices it.
        importlib.invalidate_caches()
        if _TMP_MODULE in sys.modules:
            tmp_src = importlib.reload(sys.modules[_TMP_MODULE])
        else:
            tmp_src = importlib.import_module(_TMP_MODULE)

        func = getattr(tmp_src, f"{entry_point}", None)
        if func is None:
            # Explicit raise (not ``assert``) so the check survives ``-O``.
            raise AssertionError(f"{entry_point = } not exist")

        cov = coverage.Coverage(branch=True)
        cov.start()
        try:
            with swallow_io():
                for input_list in inputs:
                    func(*input_list)
        finally:
            cov.stop()

        # "-" makes coverage write the LCOV report to stdout, captured here.
        with Capturing() as outputs:
            cov.lcov_report(outfile="-")

        # Must run before the file is removed: line mode reads the source.
        return parse_lcov(outputs, func, mode)
    finally:
        try:
            os.remove(_TMP_SOURCE_FILE)
        except FileNotFoundError:
            pass


def test_solution_coverage(
    dataset: str = "HumanEvalPlus",
    task_id: str = "HumanEval/0",
    impl: str = "canonical",
    inputs: Union[str, List[List[Any]]] = "base_input",
    mode: str = "branch",
):
    """
    Parameters:
    * dataset: {None, "HumanEval", "HumanEvalPlus"}
    * task_id: related to dataset
    * impl: {"canonical", source code}
    * inputs: {"base_input", list}
    * mode: {"branch", "line"}; "line" returns the annotated source, for
      coverage-guided LLM test generation

    Returns the result of ``test_code_coverage`` (``None`` if the code was
    refused).

    Raises:
    * NotImplementedError: ``dataset`` is None or not a HumanEval variant
    * AssertionError: ``task_id`` is not found in the dataset
    """
    if dataset is None or "HumanEval" not in dataset:
        raise NotImplementedError

    problem: Optional[dict] = None
    for p in get_human_eval_plus():
        if p["task_id"] == task_id:
            problem = p
    if problem is None:
        raise AssertionError(f"invalid {task_id = }")

    entry_point = problem["entry_point"]
    code = problem["prompt"] + (
        impl if impl != "canonical" else problem["canonical_solution"]
    )
    if inputs == "base_input":
        inputs = problem["base_input"]

    return test_code_coverage(code, inputs, entry_point, mode)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--mode", type=str, default="branch", choices=["line", "branch"]
    )
    args = parser.parse_args()

    for i in range(_NUM_HUMANEVAL_TASKS):
        task_id = f"HumanEval/{i}"
        result = test_solution_coverage(
            dataset="HumanEval", task_id=task_id, mode=args.mode
        )
        if result is None:
            # Coverage was refused; the reason has already been printed.
            continue
        if args.mode == "branch":
            # parse_lcov returns (ratio, all branches, covered branches).
            per, branch, branch_covered = result
            if per != 1.0:
                print(i, per, len(branch_covered), len(branch))
        elif "Not executed" in result:
            print(f"{task_id = }")
            print(result)