import os
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader
def custom_collate(batch):
    return {
        'src_ids': torch.stack([torch.tensor(x['src_ids']) for x in batch]),
        'src_mask': torch.stack([torch.tensor(x['src_mask']) for x in batch]),
        'tgt_ids': torch.stack([torch.tensor(x['tgt_ids']) for x in batch]),
        'tgt_mask': torch.stack([torch.tensor(x['tgt_mask']) for x in batch]),
        # Keep the test cases for verification (only meaningful during eval)
        'test_code': [x.get('test_code', "") for x in batch],
        'entry_point': [x.get('entry_point', "") for x in batch]
    }
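
# Each batch produced with custom_collate is a dict of [batch_size, max_len]
# LongTensors ('src_ids', 'src_mask', 'tgt_ids', 'tgt_mask') plus two plain
# Python string lists ('test_code', 'entry_point'); the latter are empty
# strings for datasets that carry no unit tests.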
def prepare_data(task_name, tokenizer, max_len, batch_size, split="train"):
"""
支持 split 参数,方便划分训练集和测试集
"""
print(f"Loading {task_name} ({split})...")
if task_name == "codexglue":
# 训练集:Microsoft CodeXGLUE (Python Refinement)
# 包含 GitHub Bug -> Fix
dataset = load_dataset("./code_x_glue_cc_code_refinement_full", "medium", split=split)
# 40k
if split == "train": dataset = dataset.select(range(40000))
# Case A: 标准修复数据 (有 source 和 target)
if 'source' in cols and 'target' in cols:
print(">> Detected standard refinement pairs.")
def preprocess_standard(ex):
src = tokenizer(ex['source'], max_length=max_len, padding="max_length", truncation=True)
tgt = tokenizer(ex['target'], max_length=max_len, padding="max_length", truncation=True)
return {
'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
}
preprocess_fn = preprocess_standard
# Case B: 只有代码 (有 code),需要人工注入 Bug
elif 'code' in cols:
print(">> Detected raw code. Not to inject synthetic bugs...")
else:
raise ValueError(f"Dataset columns {cols} not recognized. Need 'source'/'target' or 'code'.")
def preprocess(ex):
buggy = ex['source']
fixed = ex['target']
src = tokenizer(buggy, max_length=max_len, padding="max_length", truncation=True)
tgt = tokenizer(fixed, max_length=max_len, padding="max_length", truncation=True)
return {
'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
}
# 移除原始列
cols = dataset.column_names
elif task_name == "humanevalpack":
# 验证集:HumanEvalPack (Fix Task)
# 包含 Buggy Code 和 对应的 Unit Tests
dataset = load_dataset("./bigcode_humanevalpack_full", "python", split="test") # 只有 test 集
# 筛选出 FIX 任务
dataset = dataset.filter(lambda x: x['task_id'].startswith("Python/FIX"))
def preprocess(ex):
# prompt 是前面的描述,buggy_solution 是有 bug 的代码
# 为了简化,我们把 prompt + buggy_solution 作为输入
full_buggy = ex['prompt'] + "\n" + ex['buggy_solution']
full_fixed = ex['prompt'] + "\n" + ex['canonical_solution']
src = tokenizer(full_buggy, max_length=max_len, padding="max_length", truncation=True)
tgt = tokenizer(full_fixed, max_length=max_len, padding="max_length", truncation=True)
return {
'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask'],
'test_code': ex['test'], # 核心:保留测试代码
'entry_point': ex['entry_point'] # 核心:保留入口函数名
}
# 保留所有列用于 debug,dataset.map 会自动处理返回的 dict
cols = [] # 不自动删除列,我们需要 test 列在 collate 中处理
    elif task_name == "wiki":
        # Try a local copy first; fall back to downloading from the Hub
        try:
            dataset = load_dataset("./wikilarge-dataset")
        except Exception:
            print("Local load failed, downloading from Hub...")
            dataset = load_dataset("wikilarge")
        # Manual split: first 20k rows for train, the next 5k for test
        # (enough for a demo; the full ~290k-sentence corpus is slow)
        if split == "train":
            dataset = dataset['train'].select(range(20000))
        else:
            dataset = dataset['train'].select(range(20000, 25000))
        # Auto-detect the column names
        cols = dataset.column_names
        print(f"Wiki Dataset Columns: {cols}")
        # Map the column names to src/tgt
        if 'src' in cols and 'dst' in cols:
            src_key, tgt_key = 'src', 'dst'
        elif 'Normal' in cols and 'Simple' in cols:
            src_key, tgt_key = 'Normal', 'Simple'
        else:
            raise ValueError(f"Unknown column format for WikiLarge: {cols}")
        def preprocess(ex):
            # Source (complex) -> Target (simple)
            src = tokenizer(ex[src_key], max_length=max_len, padding="max_length", truncation=True)
            tgt = tokenizer(ex[tgt_key], max_length=max_len, padding="max_length", truncation=True)
            return {
                'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
            }
elif task_name == "mbpp":
dataset = load_dataset("mbpp", split="train[:500]")
print(f"MBPP Dataset Columns: {dataset.column_names}")
# MBPP 自重建任务: src=code, tgt=code
def preprocess(ex):
enc = tokenizer(ex['code'], max_length=max_len, padding="max_length", truncation=True)
return {
'src_ids': enc['input_ids'], 'src_mask': enc['attention_mask'],
'tgt_ids': enc['input_ids'], 'tgt_mask': enc['attention_mask']
}
else:
raise ValueError(f"Unknown task: {task_name}")
    # --- 2. Map & Batch ---
    # remove_columns=dataset.column_names drops every original column;
    # only the fields returned by preprocess (plus any it re-emits) remain.
    print(f"Preprocessing {len(dataset)} {task_name} examples...")
    dataset = dataset.map(
        preprocess,
        batched=True,
        remove_columns=dataset.column_names,
        num_proc=4
    )
    # Do not shuffle the test split, so batches stay aligned with the source data
    return DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train"), collate_fn=custom_collate)
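
# Minimal usage sketch (not part of the pipeline itself): shows how prepare_data is
# expected to be called. The tokenizer checkpoint ("Salesforce/codet5-small"), max_len
# and batch_size below are illustrative assumptions, not values taken from this project;
# substitute whatever your model and hardware require.
if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-small")  # assumed checkpoint
    train_loader = prepare_data("codexglue", tokenizer, max_len=256, batch_size=8, split="train")
    eval_loader = prepare_data("humanevalpack", tokenizer, max_len=256, batch_size=8, split="test")
    batch = next(iter(train_loader))
    print(batch['src_ids'].shape, batch['tgt_ids'].shape)  # e.g. torch.Size([8, 256]) each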