Snow2222 commited on
Commit
bb7b0e9
·
verified ·
1 Parent(s): 17a809a

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +17 -13
train.py CHANGED
@@ -15,6 +15,9 @@ if hf_token:
15
  else:
16
  raise ValueError("Hugging Face token 未设置")
17
 
 
 
 
18
  # 定义教师模型与学生模型
19
  teacher_model_name = "Qwen/Qwen1.5-7B-Chat" # 教师模型(较大模型)
20
  student_model_name = "distilgpt2" # ✅ 建议用 distilgpt2
@@ -24,7 +27,7 @@ teacher = AutoModelForCausalLM.from_pretrained(
24
  teacher_model_name,
25
  trust_remote_code=True,
26
  token=hf_token
27
- )
28
  teacher.eval() # 固定教师模型,不训练
29
 
30
  # 加载学生模型及 Tokenizer
@@ -32,7 +35,7 @@ student = AutoModelForCausalLM.from_pretrained(
32
  student_model_name,
33
  trust_remote_code=True,
34
  token=hf_token
35
- )
36
  tokenizer = AutoTokenizer.from_pretrained(
37
  student_model_name,
38
  trust_remote_code=True,
@@ -60,19 +63,19 @@ def preprocess_data(example):
60
  labels = tokenizer(example["output"], truncation=True, padding="max_length", max_length=128)
61
 
62
  return {
63
- "input_ids": inputs["input_ids"],
64
- "attention_mask": inputs["attention_mask"],
65
- "labels": labels["input_ids"]
66
  }
67
 
68
  # 预处理数据集
69
  dataset = dataset.map(preprocess_data, batched=True)
70
 
71
- # ✅ 修正 training_step() 参数问题
72
  class DistillationTrainer(Trainer):
73
  def __init__(self, teacher, *args, **kwargs):
74
  super().__init__(*args, **kwargs)
75
- self.teacher = teacher # ✅ 传入教师模型
76
 
77
  def compute_loss(self, model, inputs, return_outputs=False):
78
  labels = inputs["input_ids"]
@@ -83,7 +86,8 @@ class DistillationTrainer(Trainer):
83
 
84
  # ✅ 使用教师模型生成软标签(冻结教师参数)
85
  with torch.no_grad():
86
- outputs_teacher = self.teacher(**inputs)
 
87
  logits_teacher = outputs_teacher.logits
88
 
89
  temperature = 2.0
@@ -104,10 +108,10 @@ class DistillationTrainer(Trainer):
104
 
105
  return (loss, outputs_student) if return_outputs else loss
106
 
107
- def training_step(self, model, inputs, *args, **kwargs): # ✅ 修正:添加 *args, **kwargs 以兼容 Trainer
108
- """✅ 关键修复点:覆盖 `training_step()`,防止 `num_items_in_batch` 传递"""
109
  model.train()
110
- inputs = self._prepare_inputs(inputs)
111
  loss = self.compute_loss(model, inputs) # ✅ 直接调用,不传递 `num_items_in_batch`
112
  return loss
113
 
@@ -123,7 +127,7 @@ training_args = TrainingArguments(
123
  save_strategy="epoch",
124
  remove_unused_columns=False, # ✅ 关键设置,确保 Trainer 不删除未识别的列
125
  gradient_checkpointing=True, # ✅ 允许梯度检查点,节省显存
126
- fp16=True if torch.cuda.is_available() else False
127
  )
128
 
129
  # 初始化 Trainer
@@ -140,4 +144,4 @@ trainer.train()
140
 
141
  # 保存模型到 Hugging Face
142
  student.push_to_hub("Snow2222/fst-nnn", use_auth_token=hf_token)
143
- tokenizer.push_to_hub("Snow2222/fst-nnn", use_auth_token=hf_token)
 
15
  else:
16
  raise ValueError("Hugging Face token 未设置")
17
 
18
+ # ✅ 确保所有设备一致
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
+
21
  # 定义教师模型与学生模型
22
  teacher_model_name = "Qwen/Qwen1.5-7B-Chat" # 教师模型(较大模型)
23
  student_model_name = "distilgpt2" # ✅ 建议用 distilgpt2
 
27
  teacher_model_name,
28
  trust_remote_code=True,
29
  token=hf_token
30
+ ).to(device) # ✅ 强制放到 GPU 或 CPU
31
  teacher.eval() # 固定教师模型,不训练
32
 
33
  # 加载学生模型及 Tokenizer
 
35
  student_model_name,
36
  trust_remote_code=True,
37
  token=hf_token
38
+ ).to(device) # ✅ 也放到 GPU 或 CPU
39
  tokenizer = AutoTokenizer.from_pretrained(
40
  student_model_name,
41
  trust_remote_code=True,
 
63
  labels = tokenizer(example["output"], truncation=True, padding="max_length", max_length=128)
64
 
65
  return {
66
+ "input_ids": torch.tensor(inputs["input_ids"]).to(device), # ✅ 强制放到 GPU 或 CPU
67
+ "attention_mask": torch.tensor(inputs["attention_mask"]).to(device),
68
+ "labels": torch.tensor(labels["input_ids"]).to(device)
69
  }
70
 
71
  # 预处理数据集
72
  dataset = dataset.map(preprocess_data, batched=True)
73
 
74
+ # ✅ 修正 training_step() 设备不匹配问题
75
  class DistillationTrainer(Trainer):
76
  def __init__(self, teacher, *args, **kwargs):
77
  super().__init__(*args, **kwargs)
78
+ self.teacher = teacher.to(device) # ✅ 确保 teacher 在 GPU
79
 
80
  def compute_loss(self, model, inputs, return_outputs=False):
81
  labels = inputs["input_ids"]
 
86
 
87
  # ✅ 使用教师模型生成软标签(冻结教师参数)
88
  with torch.no_grad():
89
+ inputs_on_device = {k: v.to(device) for k, v in inputs.items()} # ✅ 确保 inputs 在 GPU
90
+ outputs_teacher = self.teacher(**inputs_on_device)
91
  logits_teacher = outputs_teacher.logits
92
 
93
  temperature = 2.0
 
108
 
109
  return (loss, outputs_student) if return_outputs else loss
110
 
111
+ def training_step(self, model, inputs, *args, **kwargs): # ✅ 兼容 Trainer 额外参数
112
+ """✅ 关键修复点:确保所有输入和模型都在 GPU"""
113
  model.train()
114
+ inputs = {k: v.to(device) for k, v in self._prepare_inputs(inputs).items()} # ✅ 确保 inputs 在 GPU
115
  loss = self.compute_loss(model, inputs) # ✅ 直接调用,不传递 `num_items_in_batch`
116
  return loss
117
 
 
127
  save_strategy="epoch",
128
  remove_unused_columns=False, # ✅ 关键设置,确保 Trainer 不删除未识别的列
129
  gradient_checkpointing=True, # ✅ 允许梯度检查点,节省显存
130
+ fp16=torch.cuda.is_available() # 自动判断是否使用 FP16
131
  )
132
 
133
  # 初始化 Trainer
 
144
 
145
  # 保存模型到 Hugging Face
146
  student.push_to_hub("Snow2222/fst-nnn", use_auth_token=hf_token)
147
+ tokenizer.push_to_hub("Snow2222/fst-nnn", use_auth_token=hf_token)