import gzip
import itertools
import json
import os
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, Iterable, List, Optional, Union

import numpy as np
from tqdm.auto import tqdm

from human_eval.data import stream_jsonl
from human_eval.execution import check_correctness


# Import/include boilerplate injected into assembled test programs
# (see process_humaneval_test below).
IMPORT_HELPER = {
    "python": [
        "import math",
        "import re",
        "import sys",
        "import copy",
        "import datetime",
        "import itertools",
        "import collections",
        "import heapq",
        "import functools",
        "import hashlib",
        "import numpy",
        "import numpy as np",
        "import string",
        "from typing import *",
        "from collections import *",
    ],
    "go": [
        "math",
        "strings",
        "fmt",
        "strconv",
        "time",
        "bytes",
        "regexp",
        "sort",
        "math/rand",
        "crypto/md5",
    ],
    "cpp": [
        "#include<stdlib.h>",
        "#include<algorithm>",
        "#include<math.h>",
        "#include<stdio.h>",
        "#include<vector>",
        "#include<string>",
        "#include<climits>",
        "#include<cstring>",
        "#include<iostream>",
        "#include<cassert>",
    ],
    "cs": [
        "using System.Numerics;",
        "using System.Diagnostics;",
        "using System.Collections.Generic;",
        "using System.Linq;",
        "using System.Text;",
        "using System.Security.Cryptography;",
    ],
}


# Display names for the supported evaluation languages.
LANGUAGE_NAME = {
    "cpp": "CPP",
    "go": "Go",
    "java": "Java",
    "js": "JavaScript",
    "python": "Python",
}


def read_dataset(
    data_file: Optional[str] = None,
    dataset_type: str = "humaneval",
    num_shot: Optional[int] = None,
) -> Dict:
    """
    Reads a dataset and returns a dictionary mapping task_id to task.
    """
    if num_shot is not None:
        print(f"{num_shot}-shot setting...")
    if "humaneval" in dataset_type.lower():
        if data_file is None:
            current_path = os.path.dirname(os.path.abspath(__file__))
            data_file = os.path.join(
                current_path,
                "..",
                "humaneval-x",
                "python",
                "data",
                "humaneval_python.jsonl.gz",
            )
        dataset = {task["task_id"]: task for task in stream_jsonl(data_file)}
    else:
        raise ValueError(f"Dataset: {dataset_type} not supported.")

    return dataset
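

# Illustrative usage: tasks are keyed by task_id, whose prefix encodes the
# language in HumanEval-X style (exact ids depend on the data file):
#
#   problems = read_dataset("humaneval_python.jsonl.gz")
#   problems["Python/0"]["prompt"]  # function signature plus docstring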


def estimate_pass_at_k(
    num_samples: Union[int, List[int], np.ndarray],
    num_correct: Union[List[int], np.ndarray],
    k: int,
) -> np.ndarray:
    """
    Estimates pass@k of each problem and returns them in an array.
    """

    def estimator(n: int, c: int, k: int) -> float:
        """
        Calculates 1 - comb(n - c, k) / comb(n, k).
        """
        if n - c < k:
            return 1.0
        return 1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1))

    if isinstance(num_samples, int):
        num_samples_it = itertools.repeat(num_samples, len(num_correct))
    else:
        assert len(num_samples) == len(num_correct)
        num_samples_it = iter(num_samples)

    return np.array(
        [estimator(int(n), int(c), k) for n, c in zip(num_samples_it, num_correct)]
    )
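

# A quick sanity check of the unbiased estimator (illustrative numbers only):
# with n = 5 samples and c = 2 correct, pass@1 = 1 - C(3, 1) / C(5, 1) = 0.4,
# and pass@5 = 1.0 since any draw of 5 samples must include a correct one.
#
#   >>> estimate_pass_at_k(5, [2], 1)
#   array([0.4])
#   >>> estimate_pass_at_k(5, [2], 5)
#   array([1.])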


def process_humaneval_test(
    sample, problems, example_test=False, is_mbpp=False, language="python"
):
    """
    Assembles a runnable test program for a sample, or returns None when the
    language is not supported.
    """
    task_id = sample["task_id"]
    if is_mbpp:
        return sample["generation"] + "\n" + "\n".join(problems[task_id]["test"])

    prompt = sample["prompt"]
    if (
        example_test
        and "example_test" in problems[task_id]
        and problems[task_id]["example_test"] != ""
    ):
        test = problems[task_id]["example_test"]
    else:
        test = problems[task_id]["test"]
    code = sample["generation"]

    if language == "python":
        test_setup = "\n".join(IMPORT_HELPER["python"]) + "\n"
        test_string = test_setup + code + "\n" + test + "\n"
    elif language == "cpp":
        # Only add includes that the prompt does not already contain.
        test_set_up = ""
        for s in IMPORT_HELPER["cpp"]:
            if s not in prompt:
                test_set_up += s + "\n"
        test_string = test_set_up + "\n" + code + "\n" + test
    elif language == "java":
        test_string = code + "\n" + test
    elif language == "cs":
        test_set_up = ""
        for s in IMPORT_HELPER["cs"]:
            test_set_up += s + "\n"
        test_string = test_set_up + "\n" + code + "\n" + test
    elif language in ["js", "javascript", "ts", "sh", "go"]:
        test_string = code + "\n" + test
    elif language == "go232":  # legacy Go-specific path; "go" is handled above
        import_string = problems[task_id]["import"]
        prompt = prompt.replace(import_string, "")
        if example_test and "example_test" in problems[task_id]:
            test = problems[task_id]["example_test"]
        else:
            test = problems[task_id]["test"]
        test_setup = problems[task_id]["test_setup"]
        # Import any helper package referenced by the code but missing from
        # the test setup.
        other_pkgs = []
        for pkg in IMPORT_HELPER["go"]:
            if pkg not in test_setup:
                p = pkg.split("/")[-1]
                if p + "." in code:
                    other_pkgs.append(f'"{pkg}"')
        if other_pkgs:
            import_other_pkgs = (
                "import (\n" + " ".join([p + "\n" for p in other_pkgs]) + ")"
            )
            test_string = (
                test_setup
                + "\n"
                + import_other_pkgs
                + "\n"
                + prompt
                + code
                + "\n"
                + test
            )
        else:
            test_string = test_setup + "\n" + prompt + code + "\n" + test
    elif language == "rust":
        main = "\nfn main(){ \n } \n"
        declaration = problems[task_id]["declaration"]
        test_string = main + declaration + prompt + code + test
    elif language == "php":
        if code[:5] != "<?php":
            code = "<?php\n" + code
        test_string = code + "\n" + test + "?>"
    else:
        test_string = None  # unsupported language; callers skip these samples
    return test_string
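

# Shape of the assembled program for a (hypothetical) Python sample:
#
#   import math                      # IMPORT_HELPER["python"], joined by "\n"
#   ...
#   def candidate(...): ...          # sample["generation"]
#   def check(candidate): ...        # problems[task_id]["test"]
#   check(candidate)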


def stream_jsonl_all(filename: str) -> Iterable[Dict]:
    """
    Reads every record from a JSONL file (plain or gzip-compressed) into a list.
    """
    results = []
    open_fn = gzip.open if filename.endswith(".gz") else open
    with open_fn(filename, "rt") as fp:
        for line in fp:
            if line.strip():  # skip blank lines
                results.append(json.loads(line))
    return results
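

# Unlike stream_jsonl (a generator), this helper materialises every record,
# so callers can take len() and iterate more than once, e.g.:
#
#   samples = stream_jsonl_all("samples.jsonl.gz")
#   print(len(samples), "generations to evaluate")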


def evaluate_functional_correctness(
    input_file: Optional[str] = None,
    tmp_dir: str = "./",
    n_workers: int = 32,
    timeout: float = 10.0,
    problem_file: str = "../data/humaneval_python.jsonl.gz",
    out_path: Optional[str] = None,
    k: List[int] = [1, 10, 100],
    test_groundtruth: bool = False,
    example_test: bool = False,
    is_mbpp: bool = False,
    language: str = "python",
):
    """
    Evaluates the functional correctness of generated samples and returns the
    pass@k scores (an empty dict when not every problem received samples).
    """
    if example_test:
        print("Example test...")

    problems = read_dataset(problem_file, dataset_type="humaneval")
    # Only read generations when we are not testing the canonical solutions.
    sample_jsonl = [] if test_groundtruth else stream_jsonl_all(input_file)

    with ThreadPoolExecutor(max_workers=n_workers) as executor:
        futures = []
        completion_id = Counter()
        n_samples = 0
        # Each task may receive several completions, so collect them in lists.
        results = defaultdict(list)

        if test_groundtruth:
            print("Testing ground truth...")
            for sample in tqdm(problems.values()):
                task_id = sample["task_id"]
                lang = task_id.split("/")[0].lower()
                if lang == "javascript":
                    lang = "js"
                tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
                sample["generation"] = sample["canonical_solution"]
                sample["test_code"] = process_humaneval_test(
                    sample, problems, example_test, language=language
                )
                if sample["test_code"] is None:
                    continue
                args = (
                    task_id,
                    sample,
                    lang,
                    timeout,
                    tmp_dir_,
                    completion_id[task_id],
                )
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1
        else:
            print("Reading samples...")
            for sample in tqdm(sample_jsonl):
                task_id = sample["task_id"]
                if is_mbpp:
                    lang = "python"
                elif language == "javascript":
                    lang = "js"
                else:
                    lang = language
                tmp_dir_ = os.path.join(tmp_dir, lang, "evaluation")
                sample["test_code"] = process_humaneval_test(
                    sample, problems, example_test, is_mbpp, language
                )
                if sample["test_code"] is None:
                    continue
                if "completion_id" in sample:
                    completion_id_ = sample["completion_id"]
                else:
                    completion_id_ = completion_id[task_id]
                args = (task_id, sample, lang, timeout, tmp_dir_, completion_id_)
                future = executor.submit(check_correctness, *args)
                futures.append(future)
                completion_id[task_id] += 1
                n_samples += 1

        # pass@k is only meaningful when every problem received samples.
        evaluate_pass_at_k = len(completion_id) == len(problems)

        print("Running test suites...")
        for future in tqdm(as_completed(futures), total=len(futures)):
            result = future.result()
            results[result["task_id"]].append(result)

    # Aggregate per-task sample totals and pass counts.
    total, correct = [], []
    for task_results in results.values():
        passed = [r["passed"] for r in task_results]
        total.append(len(passed))
        correct.append(sum(passed))
    total = np.array(total)
    correct = np.array(correct)

    if evaluate_pass_at_k:
        ks = k
        pass_at_k = {
            f"pass@{k}": estimate_pass_at_k(total, correct, k).mean()
            for k in ks
            if (total >= k).all()
        }
        print(pass_at_k)
    else:
        pass_at_k = {}
        print("Total:", np.sum(total))
        print("Correct:", np.sum(correct))

    if out_path:
        with open(out_path, "w") as f:
            json.dump(
                [r for task_results in results.values() for r in task_results],
                f,
                ensure_ascii=False,
            )

    return pass_at_k
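

# Minimal usage sketch (file names are illustrative placeholders):
#
#   if __name__ == "__main__":
#       scores = evaluate_functional_correctness(
#           input_file="samples.jsonl",  # lines with "task_id", "prompt",
#                                        # and "generation" fields
#           problem_file="../data/humaneval_python.jsonl.gz",
#           language="python",
#           k=[1, 10],
#       )
#       print(scores)                    # e.g. {"pass@1": ..., "pass@10": ...}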