| import json | |
| import os | |
| import numpy as np | |
class MBPPDataset:
    """Evaluation samples and few-shot prompts loaded from ``mbpp.jsonl``.

    The file ``<root>/mbpp.jsonl`` is read once at construction time. Each
    non-blank line must be a JSON object providing the keys ``text``,
    ``test_list``, ``code`` and ``task_id``.

    Attributes set by ``__init__``:
        root: the directory passed by the caller.
        data: the raw lines of ``mbpp.jsonl`` (line endings included).
        clean_data: one dict per record with keys ``prompt``, ``test``,
            ``code`` and ``task_id``.
        prompt: three cumulative few-shot prompt strings; entry ``k``
            is the concatenation of the first ``k + 1`` worked examples.
        testdata: the evaluation records, each repeated ``samplenum`` times
            in a row.
    """

    # Positions (0-based, in file order) of the records used as few-shot
    # examples and as evaluation samples.
    # NOTE(review): these presumably line up with the usual MBPP few-shot /
    # test split when the file is ordered by task_id -- confirm against the
    # data file actually shipped with this project.
    _FEWSHOT_INDICES = range(1, 4)
    _TEST_INDICES = range(10, 510)

    def __init__(self, root, samplenum=1):
        """Load ``mbpp.jsonl`` from *root* and build prompts and test samples.

        Args:
            root: root directory of the data files; must contain
                ``mbpp.jsonl``.
            samplenum: how many consecutive copies of every evaluation record
                to place in ``testdata`` (the copies are the same dict
                object, not clones).

        Raises:
            OSError: if the data file cannot be opened.
            IndexError: if the file holds fewer records than the hard-coded
                index ranges require.
        """
        self.root = root
        # Context manager so the handle is closed deterministically; explicit
        # encoding so decoding does not depend on the platform locale.
        with open(os.path.join(root, "mbpp.jsonl"), encoding="utf-8") as handle:
            self.data = handle.readlines()
        self.clean_data = self.get_qa_only_data(self.data)

        # Each entry extends the previous one by one more worked example.
        self.prompt = []
        accumulated = ""
        for idx in self._FEWSHOT_INDICES:
            record = self.clean_data[idx]
            prompt = record["prompt"]
            tests = "\n".join(record["test"])
            # NOTE(review): tabs are replaced by a single space exactly as
            # written in the original; confirm that a wider indent was not
            # intended.
            code = record["code"].replace("\r", "").replace("\t", " ")
            accumulated += (
                f"You are an expert Python programmer, and here is your task: {prompt} "
                f"Your code should pass these tests:\n\n{tests}\n"
                f"[BEGIN]\n{code}\n[DONE]\n"
            )
            self.prompt.append(accumulated)

        # Repeat every evaluation record `samplenum` times, back to back.
        self.testdata = [
            self.clean_data[idx]
            for idx in self._TEST_INDICES
            for _ in range(samplenum)
        ]

        # Seeds NumPy's *global* RNG; kept because code using this dataset
        # may rely on the seed being set here.
        np.random.seed(1234)
        print(f"Read MBPP from {root}, number of samples {len(self.testdata)}")

    def get_qa_only_data(self, data_json):
        """Parse JSONL lines and keep only the fields used for evaluation.

        Args:
            data_json: iterable of JSON-encoded lines.

        Returns:
            A list of dicts with keys ``prompt`` (from ``text``), ``test``
            (from ``test_list``), ``code`` and ``task_id``, in input order.
            Whitespace-only lines are skipped instead of being parsed.
        """
        records = []
        for raw in data_json:
            # A blank or trailing empty line is not valid JSON; skip it
            # rather than failing the whole load.
            if not raw.strip():
                continue
            entry = json.loads(raw)
            records.append(
                {
                    "prompt": entry["text"],
                    "test": entry["test_list"],
                    "code": entry["code"],
                    "task_id": entry["task_id"],
                }
            )
        return records

    def __len__(self):
        """Return the number of evaluation samples (repeats included)."""
        return len(self.testdata)

    def __getitem__(self, index):
        """Return the evaluation sample stored at position *index*."""
        return self.testdata[index]