Spaces:
Runtime error
Runtime error
Update train.py
Browse files
train.py
CHANGED
|
@@ -121,6 +121,7 @@ class DistillationTrainer(Trainer):
|
|
| 121 |
return loss
|
| 122 |
|
| 123 |
# 训练参数
|
|
|
|
| 124 |
training_args = TrainingArguments(
|
| 125 |
output_dir="/tmp/distilled_model",
|
| 126 |
num_train_epochs=3,
|
|
@@ -132,10 +133,12 @@ training_args = TrainingArguments(
|
|
| 132 |
save_strategy="epoch",
|
| 133 |
remove_unused_columns=False,
|
| 134 |
gradient_checkpointing=True,
|
| 135 |
-
use_cache=False, # ✅ **修正 `use_cache=True` 的冲突**
|
| 136 |
fp16=torch.cuda.is_available()
|
| 137 |
)
|
| 138 |
|
|
|
|
|
|
|
|
|
|
| 139 |
# 初始化 Trainer
|
| 140 |
trainer = DistillationTrainer(
|
| 141 |
teacher=teacher,
|
|
|
|
| 121 |
return loss
|
| 122 |
|
| 123 |
# 训练参数
|
| 124 |
+
# ✅ 移除 `use_cache` 选项
|
| 125 |
training_args = TrainingArguments(
|
| 126 |
output_dir="/tmp/distilled_model",
|
| 127 |
num_train_epochs=3,
|
|
|
|
| 133 |
save_strategy="epoch",
|
| 134 |
remove_unused_columns=False,
|
| 135 |
gradient_checkpointing=True,
|
|
|
|
| 136 |
fp16=torch.cuda.is_available()
|
| 137 |
)
|
| 138 |
|
| 139 |
+
# ✅ **手动禁用 `use_cache`**
|
| 140 |
+
student.config.use_cache = False # 🔥 这样就不会影响 `TrainingArguments`,但依然禁用了 `use_cache`
|
| 141 |
+
|
| 142 |
# 初始化 Trainer
|
| 143 |
trainer = DistillationTrainer(
|
| 144 |
teacher=teacher,
|