Snow2222 commited on
Commit
17a809a
·
verified ·
1 Parent(s): eb81ebd

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +2 -2
train.py CHANGED
@@ -68,7 +68,7 @@ def preprocess_data(example):
68
  # 预处理数据集
69
  dataset = dataset.map(preprocess_data, batched=True)
70
 
71
- # ✅ 自定义 `DistillationTrainer`,覆盖 `training_step()` 以防止 `num_items_in_batch` 传递
72
  class DistillationTrainer(Trainer):
73
  def __init__(self, teacher, *args, **kwargs):
74
  super().__init__(*args, **kwargs)
@@ -104,7 +104,7 @@ class DistillationTrainer(Trainer):
104
 
105
  return (loss, outputs_student) if return_outputs else loss
106
 
107
- def training_step(self, model, inputs):
108
  """✅ 关键修复点:覆盖 `training_step()`,防止 `num_items_in_batch` 传递"""
109
  model.train()
110
  inputs = self._prepare_inputs(inputs)
 
68
  # 预处理数据集
69
  dataset = dataset.map(preprocess_data, batched=True)
70
 
71
+ # ✅ 修正 training_step() 参数问题
72
  class DistillationTrainer(Trainer):
73
  def __init__(self, teacher, *args, **kwargs):
74
  super().__init__(*args, **kwargs)
 
104
 
105
  return (loss, outputs_student) if return_outputs else loss
106
 
107
+ def training_step(self, model, inputs, *args, **kwargs): # ✅ 修正:添加 *args, **kwargs 以兼容 Trainer
108
  """✅ 关键修复点:覆盖 `training_step()`,防止 `num_items_in_batch` 传递"""
109
  model.train()
110
  inputs = self._prepare_inputs(inputs)