Spaces:
Runtime error
Runtime error
codellama-CodeLlama-7b-hf
/
Llama2-Code-Interpreter-main
/OpenCodeInterpreter
/evaluation
/evalplus
/tools
/mbpp
/init_plus.py
| import json | |
| import os | |
| import pathlib | |
| import shutil | |
| from importlib import util | |
| from inspect import getmembers, isfunction | |
| from typing import Tuple | |
| from tempdir import TempDir | |
| from evalplus.data.mbpp import get_mbpp, mbpp_serialize_inputs | |
# Directory three levels above this file; both data locations hang off it.
_BASE_DIR = pathlib.Path(__file__).parent.parent.parent

# Destination of the JSONL dataset written by the ``__main__`` section.
MBPP_PLUS_PATH = _BASE_DIR.joinpath("MbppBase.jsonl")
# Folder holding the per-task ground-truth ``<zero-padded task id>.py`` files.
GROUNDTRUTH_MBPP_PATH = _BASE_DIR.joinpath("groundtruth/mbpp")
| def _ret(entry_point) -> str: | |
| """ | |
| This is a hacky function to return some garbages so that we can | |
| successfully run the function . | |
| """ | |
| set_assertion_func = [ | |
| "similar_elements", | |
| "find_char_long", | |
| "common_in_nested_lists", | |
| "extract_singly", | |
| "larg_nnum", | |
| "intersection_array", | |
| "k_smallest_pairs", | |
| ] | |
| if entry_point in set_assertion_func: | |
| return "()" | |
| return "1" | |
def get_entry_point(task_id: int, assertion: str) -> "str | None":
    """
    Find the name of the function that ``assertion`` exercises for a task.

    The task's ground-truth file (``<task_id zero-padded to 3 digits>.py``
    under ``GROUNDTRUTH_MBPP_PATH``) is imported -- which executes it -- and
    the functions found in it are narrowed down to those whose name occurs in
    ``assertion`` as a whole identifier.

    Args:
        task_id: Numeric MBPP task id.
        assertion: Source text of an assertion that calls the entry point.

    Returns:
        The first matching function name in alphabetical order, or ``None``
        when no function of the module is mentioned in ``assertion``.
    """
    import re  # local import so the block does not depend on file-level edits

    py_file_path = GROUNDTRUTH_MBPP_PATH / f"{str(task_id).zfill(3)}.py"
    spec = util.spec_from_file_location("inspect_module", py_file_path)
    module = util.module_from_spec(spec)
    spec.loader.exec_module(module)

    # Helper functions may sit next to the entry point; keep only names the
    # assertion really references.  A plain substring test would also accept
    # a helper whose name is a fragment of the entry point (``sum`` inside
    # ``sum_digits``) and, because getmembers() sorts by name, return it first.
    functions = [
        name
        for name, _ in getmembers(module, isfunction)
        if re.search(rf"\b{re.escape(name)}\b", assertion)
    ]
    if len(functions) > 1:
        print("more than one function: ", functions)
    return functions[0] if functions else None
def get_code_and_contract_and_assertion(task_id: int) -> Tuple[str, str, str]:
    """
    Split a task's ground-truth file into solution, contracts and assertions.

    The file ``<task_id zero-padded to 3 digits>.py`` under
    ``GROUNDTRUTH_MBPP_PATH`` is read, its first triple-double-quoted block is
    dropped, and the remaining lines are sorted into three groups:

    * lines containing the ``$_CONTRACT_$`` marker -> contract
    * other lines starting with ``assert`` (column 0) -> assertion
    * every other non-empty line -> code

    Trailing lines of the code that start with ``import`` are removed.

    Args:
        task_id: Numeric MBPP task id.

    Returns:
        ``(code, contract, assertion)``.  Each string starts with a newline;
        contract and assertion keep one line per entry, each terminated by a
        newline, and the code is followed by a single newline.
    """
    py_file_path = GROUNDTRUTH_MBPP_PATH / f"{str(task_id).zfill(3)}.py"
    # Explicit encoding: the default depends on the platform locale.
    with open(py_file_path, encoding="utf-8") as reader:
        text = reader.read()

    # Remove the docstring: only the first pair of triple double quotes.
    start_index = text.find('"""')
    end_index = text.find('"""', start_index + 3)
    if start_index != -1 and end_index != -1:
        text = text[:start_index] + text[end_index + 3 :]

    contract_lines = []
    assertion_lines = []
    code_lines = []
    for line in text.splitlines():
        # The contract marker wins over the ``assert`` prefix.
        if "$_CONTRACT_$" in line:
            contract_lines.append(line)
        elif line.startswith("assert"):
            assertion_lines.append(line)
        elif line != "":
            code_lines.append(line)

    # Drop imports left dangling at the end of the solution.
    # NOTE(review): presumably these only served the removed assertions;
    # ``from x import y`` lines are not matched by this prefix -- confirm.
    while code_lines and code_lines[-1].startswith("import"):
        code_lines.pop()

    code = "\n".join(code_lines)
    contract = "".join(line + "\n" for line in contract_lines)
    assertion = "".join(line + "\n" for line in assertion_lines)
    return "\n" + code + "\n", "\n" + contract, "\n" + assertion
def instrument_inputs(code, entry_point, test_code) -> list:
    """
    Record the argument tuples that ``test_code`` passes to ``entry_point``.

    A small script is generated and executed: everything in ``code`` that
    precedes the first ``def <entry_point>`` is kept, the entry point itself
    is replaced by a stub that appends its positional arguments to the
    module-level ``_inputs`` list and returns a dummy value chosen by
    ``_ret``, and ``test_code`` follows with every ``"assert "`` removed, so
    the comparisons are evaluated as plain expressions instead of being
    asserted against the dummy value.

    Args:
        code: Source of the reference solution.
        entry_point: Name of the function to stub out.
        test_code: Assertion source (optionally preceded by imports).

    Returns:
        The list of ``args`` tuples captured by the stub, in call order.
    """
    # Reset the capture buffer; the stub finds it through the module globals.
    globals()["_inputs"] = []
    preamble = code.split(f"def {entry_point}")[0]
    fn_text = f"""{preamble}
def {entry_point}(*args):
    _inputs.append(args)
    return {_ret(entry_point)}
"""
    script = fn_text + "\n" + test_code.replace("assert ", "")
    # Show the script before running it so a failing exec can be diagnosed.
    print(script)
    # NOTE(review): exec() runs the ground-truth source inside this module's
    # globals, so its imports/definitions persist between calls.  Only feed
    # it trusted files.
    exec(script, globals())
    print(globals()["_inputs"])
    return globals()["_inputs"]
def get_atol(task_id: int) -> float:
    """
    Return the absolute tolerance recorded for a task.

    Tasks listed below get ``1e-4``; every other task gets ``0``.
    """
    # NOTE(review): presumably the tasks whose expected answers are floats
    # (the original list was named ``float_ans_list``) -- confirm.
    tolerant_task_ids = (
        82, 85, 98, 120, 124, 137, 139, 163, 233, 246,
        248, 276, 293, 300, 312, 442, 574, 742, 746,
    )
    return 1e-4 if task_id in tolerant_task_ids else 0
if __name__ == "__main__":
    # Build MBPP_PLUS_PATH from the MBPP tasks and their ground-truth files.
    # The dataset is written to a temporary file first and copied into place
    # once every task has been processed.

    # Explicit raise instead of ``assert``: the guard must survive ``-O``.
    if MBPP_PLUS_PATH.exists():
        raise FileExistsError(f"{MBPP_PLUS_PATH} already exists!")

    # Task ids that are left out of the generated dataset.
    skipped_task_ids = {
        163, 228, 304, 408, 776, 307, 417, 443, 444, 452, 464,
        617, 627, 738, 747, 802, 393, 411, 584, 625, 756, 779,
    }

    mbpp = get_mbpp()

    with TempDir() as temp_dir:
        # Join with the bare file name.  MBPP_PLUS_PATH is normally absolute,
        # and os.path.join() discards all earlier components when a later one
        # is absolute -- the "temporary" file would then be the destination
        # itself and the final copy2() would fail with SameFileError.
        tmp_file = os.path.join(temp_dir, MBPP_PLUS_PATH.name)
        with open(tmp_file, "w") as writer:
            for task in mbpp.values():
                task_id = int(task["task_id"])
                if task_id in skipped_task_ids:
                    continue

                task["task_id"] = f"Mbpp/{task_id}"
                task["entry_point"] = get_entry_point(task_id, task["test_list"][0])
                # Prompt = original prompt plus the first test, as a docstring.
                task["prompt"] = f'"""\n{task["prompt"]}\n{task["test_list"][0]}\n"""\n'

                (
                    task["canonical_solution"],
                    task["contract"],
                    task["assertion"],
                ) = get_code_and_contract_and_assertion(task_id)
                if task["test_imports"]:
                    # The assertions need these imports when they are executed.
                    task["assertion"] = (
                        "\n".join(task["test_imports"]) + "\n" + task["assertion"]
                    )

                task["base_input"] = instrument_inputs(
                    task["canonical_solution"], task["entry_point"], task["assertion"]
                )
                task["atol"] = get_atol(task_id)

                # Fields that are not written to the output file.
                del task["source_file"]
                del task["code"]
                del task["test_list"]
                del task["test_imports"]
                del task["assertion"]

                task["base_input"] = mbpp_serialize_inputs(task_id, task["base_input"])
                writer.write(json.dumps(task) + "\n")

        # Copy only after the writer is closed (data flushed) and while the
        # temporary directory still exists.
        shutil.copy2(tmp_file, MBPP_PLUS_PATH)