Snow2222 commited on
Commit
9f0184e
·
verified ·
1 Parent(s): 566a3f6

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +3 -2
train.py CHANGED
@@ -17,7 +17,7 @@ else:
17
 
18
  # 定义教师模型与学生模型
19
  teacher_model_name = "Qwen/Qwen1.5-7B-Chat" # 教师模型(较大模型)
20
- student_model_name = "gpt2" # 学生模型(较小模型可更换)
21
 
22
  # 加载教师模型(仅用于生成软标签,不参与梯度计算)
23
  teacher = AutoModelForCausalLM.from_pretrained(
@@ -74,7 +74,7 @@ class DistillationTrainer(Trainer):
74
  super().__init__(*args, **kwargs)
75
  self.teacher = teacher # ✅ 传入教师模型
76
 
77
- def compute_loss(self, model, inputs, return_outputs=False):
78
  labels = inputs["input_ids"]
79
 
80
  # ✅ 计算学生模型的输出
@@ -115,6 +115,7 @@ training_args = TrainingArguments(
115
  logging_steps=100,
116
  save_strategy="epoch",
117
  remove_unused_columns=False, # ✅ 关键设置,确保 Trainer 不删除未识别的列
 
118
  fp16=True if torch.cuda.is_available() else False
119
  )
120
 
 
17
 
18
  # 定义教师模型与学生模型
19
  teacher_model_name = "Qwen/Qwen1.5-7B-Chat" # 教师模型(较大模型)
20
+ student_model_name = "distilgpt2" # 学生模型,建议用 distilgpt2 替代 gpt2
21
 
22
  # 加载教师模型(仅用于生成软标签,不参与梯度计算)
23
  teacher = AutoModelForCausalLM.from_pretrained(
 
74
  super().__init__(*args, **kwargs)
75
  self.teacher = teacher # ✅ 传入教师模型
76
 
77
+ def compute_loss(self, model, inputs, return_outputs=False): # ❌ 去掉 num_items_in_batch
78
  labels = inputs["input_ids"]
79
 
80
  # ✅ 计算学生模型的输出
 
115
  logging_steps=100,
116
  save_strategy="epoch",
117
  remove_unused_columns=False, # ✅ 关键设置,确保 Trainer 不删除未识别的列
118
+ gradient_checkpointing=True, # ✅ 允许梯度检查点,节省显存
119
  fp16=True if torch.cuda.is_available() else False
120
  )
121