Snow2222 commited on
Commit
515c81d
·
verified ·
1 Parent(s): 4faf472

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +6 -3
train.py CHANGED
@@ -49,12 +49,15 @@ with open('data.json', 'r', encoding='utf-8') as f:
49
  # 假设 data 是一个列表,则使用 Dataset.from_list
50
  dataset = Dataset.from_list(data)
51
 
52
- # 定义预处理,对 'instruction' 文本进行 tokenize
53
  def preprocess_data(example):
54
- return tokenizer(example['instruction'], truncation=True, padding="max_length", max_length=128)
 
 
 
55
 
56
  # 对数据集进行预处理,并移除原始文本列(此处同时移除了 'instruction' 与 'output',仅保留 tokenize 后的输入)
57
- dataset = dataset.map(preprocess_data, batched=True, remove_columns=["instruction", "output"])
58
 
59
  # 自定义知识蒸馏 Trainer,结合交叉熵损失(hard target)和 KL 散度损失(soft target)
60
  class DistillationTrainer(Trainer):
 
49
  # 假设 data 是一个列表,则使用 Dataset.from_list
50
  dataset = Dataset.from_list(data)
51
 
52
+ # 预处理数
53
  def preprocess_data(example):
54
+ return {
55
+ "instruction": example["instruction"],
56
+ "output": example["output"]
57
+ }
58
 
59
  # 对数据集进行预处理,并移除原始文本列(此处同时移除了 'instruction' 与 'output',仅保留 tokenize 后的输入)
60
+ dataset = dataset.map(preprocess_data, batched=True)
61
 
62
  # 自定义知识蒸馏 Trainer,结合交叉熵损失(hard target)和 KL 散度损失(soft target)
63
  class DistillationTrainer(Trainer):