Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
|
@@ -49,12 +49,15 @@ with open('data.json', 'r', encoding='utf-8') as f:
|
|
| 49 |
# 假设 data 是一个列表,则使用 Dataset.from_list
|
| 50 |
dataset = Dataset.from_list(data)
|
| 51 |
|
| 52 |
-
#
|
| 53 |
def preprocess_data(example):
|
| 54 |
-
return
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
# 对数据集进行预处理,并移除原始文本列(此处同时移除了 'instruction' 与 'output',仅保留 tokenize 后的输入)
|
| 57 |
-
dataset = dataset.map(preprocess_data, batched=True
|
| 58 |
|
| 59 |
# 自定义知识蒸馏 Trainer,结合交叉熵损失(hard target)和 KL 散度损失(soft target)
|
| 60 |
class DistillationTrainer(Trainer):
|
|
|
|
| 49 |
# 假设 data 是一个列表,则使用 Dataset.from_list
|
| 50 |
dataset = Dataset.from_list(data)
|
| 51 |
|
| 52 |
+
# 预处理数据
|
| 53 |
def preprocess_data(example):
|
| 54 |
+
return {
|
| 55 |
+
"instruction": example["instruction"],
|
| 56 |
+
"output": example["output"]
|
| 57 |
+
}
|
| 58 |
|
| 59 |
# 对数据集进行预处理,并移除原始文本列(此处同时移除了 'instruction' 与 'output',仅保留 tokenize 后的输入)
|
| 60 |
+
dataset = dataset.map(preprocess_data, batched=True)
|
| 61 |
|
| 62 |
# 自定义知识蒸馏 Trainer,结合交叉熵损失(hard target)和 KL 散度损失(soft target)
|
| 63 |
class DistillationTrainer(Trainer):
|