Spaces:
Runtime error
| import argparse | |
| import os | |
| from os import PathLike | |
| from model import DecoderBase, make_model | |
| from rich.progress import ( | |
| BarColumn, | |
| MofNCompleteColumn, | |
| Progress, | |
| TextColumn, | |
| TimeElapsedColumn, | |
| ) | |
def construct_contract_prompt(prompt: str, contract_type: str, contract: str) -> str:
    """Combine a task prompt with its contract according to ``contract_type``.

    Args:
        prompt: The task prompt text.
        contract_type: One of ``"none"``, ``"docstring"`` or ``"code"``.
            ``"none"`` returns ``prompt`` unchanged. ``"docstring"`` appends
            the contract to the end of the first triple-quoted string in
            ``prompt``. ``"code"`` appends the contract after ``prompt``.
        contract: Contract source text. On every line, everything from the
            first ``#`` onward is dropped before insertion.

    Returns:
        The prompt with the contract embedded.

    Raises:
        ValueError: If ``contract_type`` is not recognized, or if it is
            ``"docstring"`` and ``prompt`` contains no triple-quoted string.
    """
    if contract_type == "none":
        return prompt

    # NOTE(review): naive comment stripping -- a "#" inside a string literal
    # in the contract would also truncate that line.
    contract = "\n".join(line.split("#")[0] for line in contract.splitlines())

    if contract_type == "code":
        # at the beginning of the function
        return prompt + contract

    if contract_type == "docstring":
        # embed within the (first) docstring
        if '"""' in prompt:
            sep = '"""'
        elif "'''" in prompt:
            sep = "'''"
        else:
            raise ValueError(
                "contract_type 'docstring' requires a triple-quoted docstring in the prompt"
            )
        parts = prompt.split(sep)
        # Presumably the contract begins with a newline followed by the body
        # indent; the -1 discounts that newline so the closing quotes line up
        # with the indent. (A negative count yields no padding.)
        indent = len(contract) - len(contract.lstrip()) - 1
        parts[1] = parts[1] + contract + "\n" + " " * indent
        return sep.join(parts)

    raise ValueError(f"unknown contract_type: {contract_type!r}")
def _load_dataset(name: str):
    """Return the EvalPlus task mapping (task_id -> task) for ``name``."""
    if name == "humaneval":
        from evalplus.data import get_human_eval_plus

        return get_human_eval_plus()
    if name == "mbpp":
        from evalplus.data import get_mbpp_plus

        return get_mbpp_plus()
    raise ValueError(f"unsupported dataset: {name!r}")


def _count_existing_samples(task_dir: str) -> int:
    """Return how many ``.py`` files already exist in ``task_dir``."""
    return sum(1 for fname in os.listdir(task_dir) if fname.endswith(".py"))


def code_generate(args, workdir: PathLike, model: DecoderBase, id_range=None):
    """Generate ``args.n_samples`` completions per task and write them to disk.

    Samples for a task go to ``<workdir>/<task_id with "/" replaced by "_">/<i>.py``
    for ``i`` in ``range(args.n_samples)``.

    Args:
        args: Parsed CLI namespace. Reads ``dataset``, ``contract_type``,
            ``resume``, ``n_samples`` and ``greedy``.
        workdir: Output directory for this model/configuration.
        model: Model wrapper; ``model.codegen`` produces completions and
            ``model.conversational`` selects whether the prompt is prepended.
        id_range: Optional ``(low, high)``. Only tasks whose numeric id
            (the part after "/") lies in ``[low, high)`` are generated.

    Raises:
        ValueError: If ``args.dataset`` is not ``"humaneval"`` or ``"mbpp"``.
        RuntimeError: If the model returns no outputs for a request.
    """
    with Progress(
        TextColumn(
            f"{args.dataset} •" + "[progress.percentage]{task.percentage:>3.0f}%"
        ),
        BarColumn(),
        MofNCompleteColumn(),
        TextColumn("•"),
        TimeElapsedColumn(),
    ) as p:
        dataset = _load_dataset(args.dataset)

        for task_id, task in p.track(dataset.items()):
            if id_range is not None:
                id_num = int(task_id.split("/")[1])
                low, high = id_range
                if not low <= id_num < high:
                    p.console.print(f"Skipping {task_id} as it is not in {id_range}")
                    continue

            p_name = task_id.replace("/", "_")
            # A contract-augmented prompt is impossible without a contract.
            if args.contract_type != "none" and task["contract"] == "":
                continue

            task_dir = os.path.join(workdir, p_name)
            os.makedirs(task_dir, exist_ok=True)

            log = f"Codegen: {p_name} @ {model}"
            sidx = 0
            if args.resume:
                # NOTE(review): assumes existing samples are numbered 0..n-1
                # without gaps; new samples continue from the count.
                sidx = _count_existing_samples(task_dir)
                if sidx > 0:
                    log += f" (resuming from {sidx})"
            p.console.print(log)

            # The prompt does not change between retries, so build it once.
            prompt = construct_contract_prompt(
                task["prompt"], args.contract_type, task["contract"]
            )
            while sidx < args.n_samples:
                outputs = model.codegen(
                    prompt,
                    do_sample=not args.greedy,
                    num_samples=args.n_samples - sidx,
                )
                if not outputs:
                    # Explicit raise (not assert) so -O cannot turn this into
                    # an endless retry loop.
                    raise RuntimeError("No outputs from model!")
                for impl in outputs:
                    if sidx >= args.n_samples:
                        # Never write more than the requested number of samples.
                        break
                    # Conversational outputs are written as-is; otherwise the
                    # completion is appended to the original prompt.
                    content = impl if model.conversational else task["prompt"] + impl
                    try:
                        # Validate before opening so an unencodable sample
                        # never leaves a truncated/partial file behind.
                        content.encode("utf-8")
                    except UnicodeEncodeError:
                        continue
                    with open(
                        os.path.join(task_dir, f"{sidx}.py"),
                        "w",
                        encoding="utf-8",
                    ) as f:
                        f.write(content)
                    sidx += 1
def main():
    """Parse CLI arguments, build the model, and run code generation.

    Outputs are written under
    ``<root>/<dataset>/<model>_temp_<temperature>[-contract-<type>]/``, along
    with an ``args.txt`` recording the effective arguments.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, type=str)
    parser.add_argument("--bs", default=1, type=int)
    parser.add_argument("--temperature", default=0.0, type=float)
    parser.add_argument(
        "--dataset", required=True, type=str, choices=["humaneval", "mbpp"]
    )
    parser.add_argument("--root", type=str, required=True)
    parser.add_argument("--n_samples", default=1, type=int)
    parser.add_argument("--resume", action="store_true")
    parser.add_argument(
        "--contract-type",
        default="none",
        type=str,
        choices=["none", "code", "docstring"],
    )
    parser.add_argument("--greedy", action="store_true")
    # id_range is a list: two ints [low, high)
    parser.add_argument("--id-range", default=None, nargs="+", type=int)
    args = parser.parse_args()

    if args.greedy and (args.temperature != 0 or args.bs != 1 or args.n_samples != 1):
        # Use a float so the output directory name ("..._temp_0.0") matches
        # a greedy run that kept the default --temperature of 0.0.
        args.temperature = 0.0
        args.bs = 1
        args.n_samples = 1
        print("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0")

    if args.id_range is not None:
        # parser.error (not assert) so validation survives python -O.
        if len(args.id_range) != 2:
            parser.error("id_range must be a list of length 2")
        if args.id_range[0] >= args.id_range[1]:
            parser.error("id_range must be increasing")
        args.id_range = tuple(args.id_range)

    # Make project dir and dataset dir (makedirs creates the parent too)
    os.makedirs(os.path.join(args.root, args.dataset), exist_ok=True)

    args.model = args.model.lower()
    model = make_model(
        name=args.model, batch_size=args.bs, temperature=args.temperature
    )

    # One output directory per (model, temperature, contract type) combination.
    contract_suffix = (
        "" if args.contract_type == "none" else f"-contract-{args.contract_type}"
    )
    workdir = os.path.join(
        args.root,
        args.dataset,
        f"{args.model}_temp_{args.temperature}{contract_suffix}",
    )
    os.makedirs(workdir, exist_ok=True)
    with open(os.path.join(workdir, "args.txt"), "w", encoding="utf-8") as f:
        f.write(str(args))

    code_generate(args, workdir=workdir, model=model, id_range=args.id_range)


if __name__ == "__main__":
    main()