Snow2222 commited on
Commit
5f81635
·
verified ·
1 Parent(s): 93bf8ad

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +4 -1
train.py CHANGED
@@ -121,6 +121,7 @@ class DistillationTrainer(Trainer):
121
  return loss
122
 
123
  # 训练参数
 
124
  training_args = TrainingArguments(
125
  output_dir="/tmp/distilled_model",
126
  num_train_epochs=3,
@@ -132,10 +133,12 @@ training_args = TrainingArguments(
132
  save_strategy="epoch",
133
  remove_unused_columns=False,
134
  gradient_checkpointing=True,
135
- use_cache=False, # ✅ **修正 `use_cache=True` 的冲突**
136
  fp16=torch.cuda.is_available()
137
  )
138
 
 
 
 
139
  # 初始化 Trainer
140
  trainer = DistillationTrainer(
141
  teacher=teacher,
 
121
  return loss
122
 
123
  # 训练参数
124
+ # ✅ 移除 `use_cache` 选项
125
  training_args = TrainingArguments(
126
  output_dir="/tmp/distilled_model",
127
  num_train_epochs=3,
 
133
  save_strategy="epoch",
134
  remove_unused_columns=False,
135
  gradient_checkpointing=True,
 
136
  fp16=torch.cuda.is_available()
137
  )
138
 
139
+ # ✅ **手动禁用 `use_cache`**
140
+ student.config.use_cache = False # 🔥 这样就不会影响 `TrainingArguments`,但依然禁用了 `use_cache`
141
+
142
  # 初始化 Trainer
143
  trainer = DistillationTrainer(
144
  teacher=teacher,