Snow2222 committed on
Commit
eb81ebd
·
verified ·
1 Parent(s): 9f0184e

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +10 -3
train.py CHANGED
@@ -17,7 +17,7 @@ else:
17
 
18
  # 定义教师模型与学生模型
19
  teacher_model_name = "Qwen/Qwen1.5-7B-Chat" # 教师模型(较大模型)
20
- student_model_name = "distilgpt2" # ✅ 学生模型,建议用 distilgpt2 替代 gpt2
21
 
22
  # 加载教师模型(仅用于生成软标签,不参与梯度计算)
23
  teacher = AutoModelForCausalLM.from_pretrained(
@@ -68,13 +68,13 @@ def preprocess_data(example):
68
  # 预处理数据集
69
  dataset = dataset.map(preprocess_data, batched=True)
70
 
71
- # 自定义知识蒸馏 Trainer
72
  class DistillationTrainer(Trainer):
73
  def __init__(self, teacher, *args, **kwargs):
74
  super().__init__(*args, **kwargs)
75
  self.teacher = teacher # ✅ 传入教师模型
76
 
77
- def compute_loss(self, model, inputs, return_outputs=False): # ❌ 去掉 num_items_in_batch
78
  labels = inputs["input_ids"]
79
 
80
  # ✅ 计算学生模型的输出
@@ -104,6 +104,13 @@ class DistillationTrainer(Trainer):
104
 
105
  return (loss, outputs_student) if return_outputs else loss
106
 
 
 
 
 
 
 
 
107
  # 训练参数
108
  training_args = TrainingArguments(
109
  output_dir="/tmp/distilled_model",
 
17
 
18
  # 定义教师模型与学生模型
19
  teacher_model_name = "Qwen/Qwen1.5-7B-Chat" # 教师模型(较大模型)
20
+ student_model_name = "distilgpt2" # ✅ 建议用 distilgpt2
21
 
22
  # 加载教师模型(仅用于生成软标签,不参与梯度计算)
23
  teacher = AutoModelForCausalLM.from_pretrained(
 
68
  # 预处理数据集
69
  dataset = dataset.map(preprocess_data, batched=True)
70
 
71
+ # 自定义 `DistillationTrainer`,覆盖 `training_step()` 以防止 `num_items_in_batch` 传递
72
  class DistillationTrainer(Trainer):
73
  def __init__(self, teacher, *args, **kwargs):
74
  super().__init__(*args, **kwargs)
75
  self.teacher = teacher # ✅ 传入教师模型
76
 
77
+ def compute_loss(self, model, inputs, return_outputs=False):
78
  labels = inputs["input_ids"]
79
 
80
  # ✅ 计算学生模型的输出
 
104
 
105
  return (loss, outputs_student) if return_outputs else loss
106
 
107
def training_step(self, model, inputs, num_items_in_batch=None):
    """Run one optimisation step's forward and backward pass.

    This override exists so that ``num_items_in_batch`` is never forwarded
    to ``compute_loss``, whose signature in this class does not accept it.

    Args:
        model: The student model being trained.
        inputs: The raw batch; it is passed through ``_prepare_inputs``
            before the loss is computed.
        num_items_in_batch: Accepted because recent ``Trainer`` training
            loops pass it as a third argument; intentionally ignored here.

    Returns:
        The detached loss tensor for logging.
    """
    model.train()
    inputs = self._prepare_inputs(inputs)

    # Call compute_loss directly so num_items_in_batch is not passed along.
    loss = self.compute_loss(model, inputs)

    # Under multi-GPU data parallelism the loss has one entry per device;
    # reduce it to a scalar before backpropagating.
    if self.args.n_gpu > 1:
        loss = loss.mean()

    # The Trainer loop expects training_step to perform the backward pass;
    # without it the optimizer steps on empty gradients.
    # NOTE(review): gradient-accumulation loss scaling differs between
    # transformers versions -- confirm against the installed release.
    self.accelerator.backward(loss)

    # Detach so the caller's running loss does not keep the graph alive.
    return loss.detach()
113
+
114
  # 训练参数
115
  training_args = TrainingArguments(
116
  output_dir="/tmp/distilled_model",