Snow2222 commited on
Commit
e131b6c
·
verified ·
1 Parent(s): 5edcbdb

Create train.py

Browse files
Files changed (1) hide show
  1. train.py +58 -0
train.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
3
+ from peft import LoraConfig, get_peft_model
4
+ from datasets import load_dataset
5
+
6
+ # 加载DeepSeek R1模型
7
+ model_name = "DeepSeek/R1" # 你可以根据实际选择不同的路径
8
+
9
+ # 加载模型和分词器
10
+ model = AutoModelForCausalLM.from_pretrained(model_name)
11
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
12
+
13
+ # 配置 LoRA 微调
14
+ lora_config = LoraConfig(
15
+ r=8, # LoRA参数
16
+ lora_alpha=16,
17
+ lora_dropout=0.1,
18
+ bias="none",
19
+ )
20
+
21
+ # 获取LoRA微调模型
22
+ model = get_peft_model(model, lora_config)
23
+
24
+ # 准备数据集
25
+ data = [
26
+ {"instruction": "粉丝通跨店版的费用是多少?", "output": "粉丝通跨店版按月付费,500元/月,仅提供增值税普通电子发票。"},
27
+ {"instruction": "如何充值粉丝通软件的红包?", "output": "商家可以灵活充值红包,每个红包最低0.1元,具体总额根据拉新目标决定。"},
28
+ {"instruction": "红包的扣费机制是怎样的?", "output": "红包在用户实际使用后才会扣款,未使用到期会自动退回商家公户。"},
29
+ # 你可以继续添加数据...
30
+ ]
31
+
32
+ # 转换数据为 Hugging Face 数据集格式
33
+ train_data = [{"input_ids": tokenizer.encode(d["instruction"], truncation=True, padding="max_length"), "labels": tokenizer.encode(d["output"], truncation=True, padding="max_length")} for d in data]
34
+
35
+ train_dataset = load_dataset('json', data_files={'train': train_data})
36
+
37
+ # 设置训练参数
38
+ training_args = TrainingArguments(
39
+ output_dir='./results',
40
+ evaluation_strategy="epoch",
41
+ learning_rate=5e-5,
42
+ per_device_train_batch_size=2,
43
+ per_device_eval_batch_size=2,
44
+ num_train_epochs=3,
45
+ weight_decay=0.01,
46
+ save_steps=10_000,
47
+ save_total_limit=2,
48
+ )
49
+
50
+ # 设置 Trainer
51
+ trainer = Trainer(
52
+ model=model,
53
+ args=training_args,
54
+ train_dataset=train_dataset["train"],
55
+ )
56
+
57
+ # 开始训练
58
+ trainer.train()