RianLi committed on
Commit
61b0657
·
verified ·
1 Parent(s): 23426c8

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +37 -0
  2. data.json +22 -0
  3. fine_tune.py +82 -0
  4. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+ import gradio as gr
3
+
4
def train():
    """Install the dependencies, run the fine-tuning script and stream the log.

    Generator intended to be used as the Gradio click handler. Every yielded
    value is the complete log collected so far (not just the newest line), so
    a component that shows only the most recent yielded value still displays
    the whole history.

    Two steps run in order, each as a child process with stderr merged into
    stdout:

    1. ``pip install -r requirements.txt``
    2. ``fine_tune.py``

    If a step cannot be started or exits with a non-zero status, a failure
    marker is appended to the log and the remaining steps are skipped.

    Yields:
        str: The full log text accumulated up to this point.
    """
    # Imported locally so the function does not rely on a module-level
    # ``import sys`` being present in the file.
    import sys

    # Use the running interpreter for both steps so packages are installed
    # into, and the script runs in, the same environment as this app.
    python = sys.executable or 'python3'
    steps = (
        (
            [python, '-m', 'pip', 'install', '-r', 'requirements.txt'],
            "---依赖安装完成,开始训练---",
            "依赖安装失败",
        ),
        (
            # -u turns off output buffering in the child so lines arrive live.
            [python, '-u', 'fine_tune.py'],
            "---训练完成!---",
            "训练失败",
        ),
    )

    log = ""
    for command, done_marker, failed_label in steps:
        try:
            process = subprocess.Popen(
                command,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
            )
        except OSError as exc:
            # The executable could not be launched at all.
            log += f"---{failed_label}: {exc}---\n"
            yield log
            return

        # Leaving the with-block closes the pipe and waits for the child,
        # which also populates process.returncode.
        with process:
            for line in process.stdout:
                log += line
                yield log

        if process.returncode != 0:
            log += f"---{failed_label}(退出码 {process.returncode})---\n"
            yield log
            return

        log += f"{done_marker}\n"
        yield log
30
+
31
# Page layout: a heading, a 20-line log box, and a button whose click
# handler is the train() generator; its yielded values go to the log box.
with gr.Blocks() as demo:
    gr.Markdown("点击按钮开始微调")
    output = gr.Textbox(lines=20, label="训练日志")
    train_button = gr.Button("开始微调")
    train_button.click(train, [], output)

# Start the web server for the page defined above.
demo.launch()
data.json ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "instruction": "根据以下信息,生成一个用户JSON对象。",
4
+ "input": "用户ID是123,用户名是alice,邮箱是alice@example.com",
5
+ "output": "{\"user_id\": 123, \"username\": \"alice\", \"email\": \"alice@example.com\"}"
6
+ },
7
+ {
8
+ "instruction": "根据以下信息,生成一个用户JSON对象。",
9
+ "input": "用户ID是456,用户名是bob,邮箱是bob@example.com",
10
+ "output": "{\"user_id\": 456, \"username\": \"bob\", \"email\": \"bob@example.com\"}"
11
+ },
12
+ {
13
+ "instruction": "根据以下信息,生成一个用户JSON对象。",
14
+ "input": "用户ID是789,用户名是charlie,邮箱是charlie@example.com",
15
+ "output": "{\"user_id\": 789, \"username\": \"charlie\", \"email\": \"charlie@example.com\"}"
16
+ },
17
+ {
18
+ "instruction": "根据以下信息,生成一个用户JSON对象。",
19
+ "input": "用户ID是101,用户名是dave,邮箱是dave@example.com",
20
+ "output": "{\"user_id\": 101, \"username\": \"dave\", \"email\": \"dave@example.com\"}"
21
+ }
22
+ ]
fine_tune.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from datasets import load_dataset
3
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
5
+ from trl import SFTTrainer
6
+
7
# 1. Load the model and the tokenizer
model_name = "NousResearch/Llama-2-7b-chat-hf"

# Quantization settings for QLoRA: 4-bit NF4 weights with double
# quantization, computing in bfloat16.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    load_in_4bit=True,
)

# Load the quantized base model.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config,
)
# The generation cache is switched off for training.
model.config.use_cache = False

# Load the matching tokenizer and reuse the end-of-sequence token as the
# padding token.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
29
+
30
+ # 2. Load and prepare the dataset
31
def formatting_prompts_func(example):
    """Render dataset rows as instruction / input / response prompt text.

    Works on a batch (each column is a list of values) as well as on a single
    row (each column is a plain string).
    NOTE(review): whether the trainer calls this per batch or per row
    presumably depends on the installed TRL version -- confirm.

    Args:
        example: Mapping with 'instruction', 'input' and 'output' keys.

    Returns:
        A list with one prompt string per row when given a batch, or a single
        prompt string when given one row.
    """
    template = "### Instruction:\n{}\n\n### Input:\n{}\n\n### Response:\n{}"
    columns = (example['instruction'], example['input'], example['output'])

    # A string column means one un-batched row; iterating it would walk over
    # characters instead of rows.
    if isinstance(columns[0], str):
        return template.format(*columns)

    return [template.format(*row) for row in zip(*columns)]
37
+
38
# Read the training rows from the local JSON file.
dataset = load_dataset("json", split="train", data_files="data.json")

# 3. Configure the LoRA parameters
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Llama-2 specific modules
    r=8,  # Rank
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
)

# 4. Create the PEFT model: prepare the quantized model for k-bit training,
# then attach the LoRA adapter.
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)
53
+
54
# 5. Configure the training arguments
output_dir = "./llama-2-7b-chat-json"
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=10,
    max_steps=100,  # for demo
    save_strategy="epoch",
    # num_train_epochs=1,  # use max_steps for demo
)

# 6. Create the trainer and start training
# `model` was already wrapped with get_peft_model() earlier in this script, so
# peft_config is deliberately not passed here: handing the trainer an
# already-wrapped model together with a LoRA config asks it to apply the
# adapter a second time.
# NOTE(review): presumably some TRL releases reject that combination outright,
# and newer releases may expect max_seq_length on the config object instead of
# the trainer -- confirm against the installed version.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    args=training_args,
    formatting_func=formatting_prompts_func,
    max_seq_length=512,
)

trainer.train()

# 7. Save the model (the trained LoRA adapter) to the output directory
print("Saving LoRA adapter...")
trainer.save_model(output_dir)
print(f"LoRA adapter saved to {output_dir}")
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ peft
4
+ trl
5
+ bitsandbytes
6
+ datasets
7
+ gradio