import os
from datasets import load_dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader

def custom_collate(batch):
    return {
        'src_ids': torch.stack([torch.tensor(x['src_ids']) for x in batch]),
        'src_mask': torch.stack([torch.tensor(x['src_mask']) for x in batch]),
        'tgt_ids': torch.stack([torch.tensor(x['tgt_ids']) for x in batch]),
        'tgt_mask': torch.stack([torch.tensor(x['tgt_mask']) for x in batch]),
        # Keep the test cases for validation (only populated for the eval sets)
        'test_code': [x.get('test_code', "") for x in batch],
        'entry_point': [x.get('entry_point', "") for x in batch]
    }

def prepare_data(task_name, tokenizer, max_len, batch_size, split="train"):
    """
    支持 split 参数,方便划分训练集和测试集
    """
    print(f"Loading {task_name} ({split})...")
    
    if task_name == "codexglue":
        # 训练集:Microsoft CodeXGLUE (Python Refinement)
        # 包含 GitHub Bug -> Fix
        dataset = load_dataset("./code_x_glue_cc_code_refinement_full", "medium", split=split)

        # 40k
        if split == "train": dataset = dataset.select(range(40000)) 

        # Case A: 标准修复数据 (有 source 和 target)
        if 'source' in cols and 'target' in cols:
            print(">> Detected standard refinement pairs.")
            def preprocess_standard(ex):
                src = tokenizer(ex['source'], max_length=max_len, padding="max_length", truncation=True)
                tgt = tokenizer(ex['target'], max_length=max_len, padding="max_length", truncation=True)
                return {
                    'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                    'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
                }
            preprocess_fn = preprocess_standard

        # Case B: 只有代码 (有 code),需要人工注入 Bug
        elif 'code' in cols:
            print(">> Detected raw code. Not to inject synthetic bugs...")
            
        else:
            raise ValueError(f"Dataset columns {cols} not recognized. Need 'source'/'target' or 'code'.")
        
        def preprocess(ex):
            buggy = ex['source']
            fixed = ex['target']
            
            src = tokenizer(buggy, max_length=max_len, padding="max_length", truncation=True)
            tgt = tokenizer(fixed, max_length=max_len, padding="max_length", truncation=True)
            
            return {
                'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
            }
        
        # 移除原始列
        cols = dataset.column_names
    
    elif task_name == "humanevalpack":
        # 验证集:HumanEvalPack (Fix Task)
        # 包含 Buggy Code 和 对应的 Unit Tests
        dataset = load_dataset("./bigcode_humanevalpack_full", "python", split="test") # 只有 test 集
        
        # 筛选出 FIX 任务
        dataset = dataset.filter(lambda x: x['task_id'].startswith("Python/FIX"))
        
        def preprocess(ex):
            # prompt 是前面的描述,buggy_solution 是有 bug 的代码
            # 为了简化,我们把 prompt + buggy_solution 作为输入
            full_buggy = ex['prompt'] + "\n" + ex['buggy_solution']
            full_fixed = ex['prompt'] + "\n" + ex['canonical_solution']
            
            src = tokenizer(full_buggy, max_length=max_len, padding="max_length", truncation=True)
            tgt = tokenizer(full_fixed, max_length=max_len, padding="max_length", truncation=True)
            
            return {
                'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask'],
                'test_code': ex['test'],         # 核心:保留测试代码
                'entry_point': ex['entry_point'] # 核心:保留入口函数名
            }
        
        # 保留所有列用于 debug,dataset.map 会自动处理返回的 dict
        cols = [] # 不自动删除列,我们需要 test 列在 collate 中处理
    
    elif task_name == "wiki":
        # Try a local copy first; fall back to downloading from the Hub
        try:
            dataset = load_dataset("./wikilarge-dataset")
        except Exception:
            print("Local load failed, downloading from Hub...")
            dataset = load_dataset("wikilarge")
        
        # Manual split: rows [0, 20000) for train, [20000, 25000) for test
        # (plenty for a demo; the full corpus is slow to process)
        if split == "train":
            dataset = dataset['train'].select(range(20000))
        else:
            # The full train split has roughly 290k rows; take a slice right after the training range
            dataset = dataset['train'].select(range(20000, 25000))

        # Auto-detect the column names
        cols = dataset.column_names
        print(f"Wiki Dataset Columns: {cols}")
        
        # Map the detected column names to src/tgt roles
        if 'src' in cols and 'dst' in cols:
            src_key, tgt_key = 'src', 'dst'
        elif 'Normal' in cols and 'Simple' in cols:
            src_key, tgt_key = 'Normal', 'Simple'
        else:
            raise ValueError(f"Unknown column format for WikiLarge: {cols}")

        def preprocess(ex):
            # Source (Complex) -> Target (Simple)
            src = tokenizer(ex[src_key], max_length=max_len, padding="max_length", truncation=True)
            tgt = tokenizer(ex[tgt_key], max_length=max_len, padding="max_length", truncation=True)
            return {
                'src_ids': src['input_ids'], 'src_mask': src['attention_mask'],
                'tgt_ids': tgt['input_ids'], 'tgt_mask': tgt['attention_mask']
            }
        
    elif task_name == "mbpp":
        dataset = load_dataset("mbpp", split="train[:500]")
        print(f"MBPP Dataset Columns: {dataset.column_names}")
        
        # MBPP self-reconstruction task: src = code, tgt = code (identity mapping)
        def preprocess(ex):
            enc = tokenizer(ex['code'], max_length=max_len, padding="max_length", truncation=True)
            return {
                'src_ids': enc['input_ids'], 'src_mask': enc['attention_mask'],
                'tgt_ids': enc['input_ids'], 'tgt_mask': enc['attention_mask']
            }
    
    else:
        raise ValueError(f"Unknown task: {task_name}")

    # --- 2. Map & Batch ---
    print(f"Preprocessing {task_name} data...")
    # 使用 remove_columns=dataset.column_names 确保删除所有原始列
    print(f"Preprocessing {len(dataset)} examples...")
    dataset = dataset.map(
        preprocess, 
        batched=True, 
        remove_columns=dataset.column_names, 
        num_proc=4
    )
    
    # Do not shuffle the test split so outputs stay aligned with their source examples
    return DataLoader(dataset, batch_size=batch_size, shuffle=(split=="train"), collate_fn=custom_collate)
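

# Usage sketch (illustrative only): the tokenizer checkpoint, max_len and batch_size below
# are assumptions for demonstration, not values taken from this repository.
if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("Salesforce/codet5-small")  # any seq2seq tokenizer works here
    train_loader = prepare_data("codexglue", tok, max_len=256, batch_size=8, split="train")
    eval_loader = prepare_data("humanevalpack", tok, max_len=256, batch_size=8, split="test")

    batch = next(iter(train_loader))
    # Each id/mask tensor has shape (batch_size, max_len); test_code/entry_point are lists of strings.
    print(batch['src_ids'].shape, batch['tgt_ids'].shape)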