| | import os |
| | import pickle |
| | import sys |
| | from importlib import import_module |
| | from io import StringIO |
| | from typing import Any, Dict, List |
| |
|
| | import coverage |
| | from rich.progress import track |
| |
|
| | from evalplus.eval.utils import swallow_io |
| | from tools.tsr.utils import get_problems, get_task_ids, to_path |
| |
|
| |
|
class Capturing(list):
    """List that collects the lines printed to ``sys.stdout`` inside a ``with`` block.

    Entering the context swaps ``sys.stdout`` for an in-memory buffer.
    Leaving it appends the buffered text, split into lines, to this list
    and puts the previous ``sys.stdout`` back.
    """

    def __enter__(self):
        # Remember the active stream so __exit__ can restore it.
        self._stdout = sys.stdout
        self._stringio = StringIO()
        sys.stdout = self._stringio
        return self

    def __exit__(self, *args):
        captured_text = self._stringio.getvalue()
        self.extend(captured_text.splitlines())
        # The buffer is no longer needed once its content has been copied.
        del self._stringio
        sys.stdout = self._stdout
| |
|
| |
|
def parse_lcov(outputs: List[str]):
    """Extract branch coverage from the ``tmp_src`` record(s) of an LCOV report.

    Args:
        outputs: Lines of an LCOV report, one entry per element.

    Returns:
        A ``(per, branch, branch_covered)`` tuple. ``branch`` holds one
        ``"BR:<line>,<block>,<branch>"`` signature per ``BRDA`` entry found
        in the selected record(s); ``branch_covered`` is the subset whose
        taken field is neither ``"0"`` nor ``"-"``; ``per`` is the covered
        fraction, or ``1.0`` when no branch entry was found.

    Raises:
        ValueError: If a ``BRDA`` entry has fewer than four comma-separated
            fields.
    """
    # A record starts at the first line mentioning "tmp_src" (kept) and
    # stops at the next "end_of_record" line (dropped).
    in_record = False
    record_lines = []
    for line in outputs:
        if not in_record and "tmp_src" in line:
            in_record = True
        if in_record and "end_of_record" in line:
            in_record = False
        if in_record:
            record_lines.append(line)

    branch, branch_covered = [], []
    for line in record_lines:
        if not line.startswith("BRDA"):
            continue
        # Entry layout: BRDA:<line>,<block>,<branch>,<taken>. The first two
        # fields are split from the left and the taken field from the right,
        # so a branch field that itself contains commas is still parsed.
        lineno, blockno, remainder = line[5:].split(",", 2)
        branchno, taken = remainder.rsplit(",", 1)
        branch_sig = f"BR:{lineno},{blockno},{branchno}"
        branch.append(branch_sig)
        # "0" means never taken; "-" is likewise counted as not covered.
        if taken not in ("0", "-"):
            branch_covered.append(branch_sig)

    per = len(branch_covered) / len(branch) if branch else 1.0
    return per, branch, branch_covered
| |
|
| |
|
def test_code_coverage(
    identifier: str, code: str, inputs: List[List[Any]], entry_point: str
):
    """Run ``entry_point`` from ``code`` on ``inputs`` and measure branch coverage.

    The code is written to ``tmp_src_<identifier>.py`` in the current working
    directory, imported as a module, and its entry point is called once per
    input list while ``coverage`` records branches. The LCOV report is then
    captured from stdout and handed to ``parse_lcov``. The temporary source
    file is removed afterwards, even when a step fails.

    Args:
        identifier: Suffix used to build the temporary module name.
        code: Python source that defines ``entry_point``.
        inputs: Argument lists; each one is unpacked into a single call.
        entry_point: Name of the callable to look up in the imported module.

    Returns:
        The ``(per, branch, branch_covered)`` tuple produced by ``parse_lcov``.

    Raises:
        AssertionError: If the imported module has no ``entry_point`` attribute.
    """
    module_name = f"tmp_src_{identifier}"
    source_path = f"{module_name}.py"
    # Python decodes source files as UTF-8 by default, so write it that way
    # instead of relying on the locale encoding.
    with open(source_path, "w", encoding="utf-8") as f:
        f.write(code)

    try:
        # NOTE(review): import_module returns the cached module when the same
        # identifier was imported earlier in this process, so a repeated
        # identifier with different code would reuse the old module - confirm
        # callers always pair an identifier with the same code.
        # NOTE(review): presumably the working directory must be on sys.path
        # for this import to find the file - confirm.
        mod = import_module(module_name)
        func = getattr(mod, entry_point, None)
        if func is None:
            # Raised explicitly so the check also runs under ``python -O``.
            raise AssertionError(
                f"entry_point = {entry_point} not exist, code: {code}"
            )

        cov = coverage.Coverage(branch=True)
        cov.start()
        try:
            with swallow_io():
                for input_list in inputs:
                    func(*input_list)
        finally:
            # Stop measuring even when the function under test raises.
            cov.stop()

        # The report is written to stdout ("-") and captured line by line.
        with Capturing() as outputs:
            cov.lcov_report(outfile="-")

        return parse_lcov(outputs)
    finally:
        # The report has been generated by now (or a step failed), so the
        # temporary source file can go.
        os.remove(source_path)
| |
|
| |
|
def collect_coverage_info(coverage_dir: str, dataset: str) -> Dict[str, Dict[str, Any]]:
    """Collect, per task and per plus-test, the branches the ground truth covers.

    Each task's result is cached as ``<coverage_dir>/<to_path(task_id)>.pkl``.
    When that file already exists it is loaded instead of re-running the
    tests.

    Args:
        coverage_dir: Directory for the per-task cache files; created if
            missing.
        dataset: Dataset name forwarded to ``get_problems`` and
            ``get_task_ids``.

    Returns:
        A mapping ``task_id -> {"plus_<i>": [(branch_signature, "gt"), ...]}``
        for freshly computed tasks; cached tasks map to whatever their pickle
        file contains.
    """
    os.makedirs(coverage_dir, exist_ok=True)
    problems = get_problems(dataset)
    task_ids = get_task_ids(dataset)
    coverage_info = {task_id: {} for task_id in task_ids}

    for task_id in track(task_ids, description="Testing gt coverage..."):
        cache_file = os.path.join(coverage_dir, f"{to_path(task_id)}.pkl")

        # Reuse an earlier result when one is on disk.
        # NOTE(review): pickle.load executes arbitrary code from the file -
        # presumably the cache directory is trusted; confirm.
        if os.path.isfile(cache_file):
            with open(cache_file, "rb") as cached:
                coverage_info[task_id] = pickle.load(cached)
            continue

        problem = problems[task_id]
        solution_src = problem["prompt"] + problem["canonical_solution"]
        plus_inputs = problem["plus_input"]
        entry_point = problem["entry_point"]

        task_coverage = coverage_info[task_id]
        # Each plus-test is measured on its own so coverage is per test.
        for index, single_input in enumerate(plus_inputs):
            _, _, hit_branches = test_code_coverage(
                to_path(task_id), solution_src, [single_input], entry_point
            )
            tagged = [(signature, "gt") for signature in hit_branches]
            task_coverage.setdefault(f"plus_{index}", []).extend(tagged)

        with open(cache_file, "wb") as cached:
            pickle.dump(task_coverage, cached)

    return coverage_info
| |
|
| |
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: build the coverage cache for one dataset.
    # NOTE(review): --dataset is not marked required, so None is passed on
    # when the flag is omitted - confirm that is intended.
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", type=str, choices=["humaneval", "mbpp"])
    cli.add_argument("--report_dir", required=True, type=str)
    options = cli.parse_args()

    # Cache files go into a "coverage_cache" folder under the report directory.
    collect_coverage_info(
        os.path.join(options.report_dir, "coverage_cache"), options.dataset
    )
| |
|