Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
|
@@ -68,7 +68,7 @@ def preprocess_data(example):
|
|
| 68 |
# 预处理数据集
|
| 69 |
dataset = dataset.map(preprocess_data, batched=True)
|
| 70 |
|
| 71 |
-
# ✅
|
| 72 |
class DistillationTrainer(Trainer):
|
| 73 |
def __init__(self, teacher, *args, **kwargs):
|
| 74 |
super().__init__(*args, **kwargs)
|
|
@@ -104,7 +104,7 @@ class DistillationTrainer(Trainer):
|
|
| 104 |
|
| 105 |
return (loss, outputs_student) if return_outputs else loss
|
| 106 |
|
| 107 |
-
def training_step(self, model, inputs):
|
| 108 |
"""✅ 关键修复点:覆盖 `training_step()`,防止 `num_items_in_batch` 传递"""
|
| 109 |
model.train()
|
| 110 |
inputs = self._prepare_inputs(inputs)
|
|
|
|
| 68 |
# 预处理数据集
|
| 69 |
dataset = dataset.map(preprocess_data, batched=True)
|
| 70 |
|
| 71 |
+
# ✅ 修正 training_step() 参数问题
|
| 72 |
class DistillationTrainer(Trainer):
|
| 73 |
def __init__(self, teacher, *args, **kwargs):
|
| 74 |
super().__init__(*args, **kwargs)
|
|
|
|
| 104 |
|
| 105 |
return (loss, outputs_student) if return_outputs else loss
|
| 106 |
|
| 107 |
+
def training_step(self, model, inputs, *args, **kwargs): # ✅ 修正:添加 *args, **kwargs 以兼容 Trainer
|
| 108 |
"""✅ 关键修复点:覆盖 `training_step()`,防止 `num_items_in_batch` 传递"""
|
| 109 |
model.train()
|
| 110 |
inputs = self._prepare_inputs(inputs)
|