RianLi committed
Commit 2b9d7ab · verified · 1 Parent(s): cb6e241

Upload 4 files

Files changed (2):
  1. app.py +123 -25
  2. fine_tune.py +24 -26
app.py CHANGED
@@ -1,37 +1,135 @@
-import subprocess
 import gradio as gr
+import subprocess
+import threading
+import time
 
-def train():
-    # Install dependencies
-    process = subprocess.Popen(
-        ['pip', 'install', '-r', 'requirements.txt'],
-        stdout=subprocess.PIPE,
-        stderr=subprocess.STDOUT,
-        text=True
-    )
-    for line in iter(process.stdout.readline, ''):
-        yield line
-    process.wait()
-
-    yield "--- Dependencies installed, starting training ---"
-
-    # Run the training script
+def get_md_content(file_path):
+    try:
+        with open(file_path, 'r', encoding='utf-8') as f:
+            return f.read()
+    except FileNotFoundError:
+        return f"Error: {file_path} not found."
+    except Exception as e:
+        return f"An error occurred: {e}"
+
+def run_script():
+    """Function to run the fine-tuning script and stream output."""
     process = subprocess.Popen(
         ['python3', 'fine_tune.py'],
         stdout=subprocess.PIPE,
         stderr=subprocess.STDOUT,
-        text=True
+        text=True,
+        bufsize=1,
+        universal_newlines=True
     )
-    for line in iter(process.stdout.readline, ''):
-        yield line
+    output = ""
+    for line in process.stdout:
+        output += line
+        yield output
     process.wait()
 
-    yield "--- Training finished! ---"
-
-with gr.Blocks() as demo:
-    gr.Markdown("Click the button to start fine-tuning")
-    output = gr.Textbox(label="Training log", lines=20)
-    train_button = gr.Button("Start fine-tuning")
-    train_button.click(fn=train, inputs=[], outputs=output)
-
-demo.launch()
+# JavaScript to find and render Mermaid diagrams
+js_script = """
+() => {
+    function initMermaidAndConvert() {
+        // Wait for mermaid to be available
+        if (typeof mermaid === 'undefined') {
+            console.log('Mermaid not loaded yet, retrying...');
+            setTimeout(initMermaidAndConvert, 100);
+            return;
+        }
+
+        console.log('Mermaid loaded successfully');
+        // Initialize mermaid
+        mermaid.initialize({
+            startOnLoad: false,
+            theme: 'default',
+            securityLevel: 'loose'
+        });
+
+        function convertMermaidCodeBlocks() {
+            console.log('Converting Mermaid code blocks...');
+            let processedCount = 0;
+
+            // Look for pre blocks that contain mermaid syntax
+            document.querySelectorAll('pre').forEach((pre, index) => {
+                const code = pre.querySelector('code');
+                if (code && !pre.classList.contains('mermaid-processed')) {
+                    const text = code.textContent.trim();
+                    // Check if it contains mermaid syntax
+                    const isMermaid = text.includes('graph ') ||
+                        text.includes('flowchart ') ||
+                        text.includes('subgraph ') ||
+                        text.startsWith('graph') ||
+                        text.startsWith('flowchart') ||
+                        text.includes('classDiagram') ||
+                        text.includes('sequenceDiagram');
+
+                    if (isMermaid) {
+                        console.log(`Found Mermaid diagram ${processedCount + 1}:`, text.substring(0, 50) + '...');
+                        pre.classList.add('mermaid');
+                        pre.classList.add('mermaid-processed');
+                        pre.textContent = text;
+                        processedCount++;
+                    }
+                }
+            });
+
+            console.log(`Processed ${processedCount} Mermaid diagrams`);
+            // Run Mermaid
+            try {
+                mermaid.run();
+            } catch (e) {
+                console.log('Mermaid rendering error:', e);
+            }
+        }
+
+        // Use a MutationObserver to re-run the conversion when Gradio updates the page
+        const observer = new MutationObserver((mutations) => {
+            // A simple debounce to avoid excessive re-renders
+            clearTimeout(window.mermaidTimeout);
+            window.mermaidTimeout = setTimeout(convertMermaidCodeBlocks, 100);
+        });
+        observer.observe(document.body, { childList: true, subtree: true });
+
+        // Initial run
+        convertMermaidCodeBlocks();
+    }
+
+    // Start the initialization
+    initMermaidAndConvert();
+}
+"""
+
+# HTML to include the Mermaid.js library
+head_script = '<script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>'
+
+with gr.Blocks(theme=gr.themes.Soft(), head=head_script, js=js_script) as demo:
+    gr.Markdown("# Fine-Tuning Tech Share")
+
+    with gr.Tabs():
+        with gr.TabItem("Talk Outline"):
+            gr.Markdown(get_md_content("outline.md"))
+
+        with gr.TabItem("Core Techniques Overview"):
+            gr.Markdown(get_md_content("presentation.md"))
+
+        with gr.TabItem("LoRA & QLoRA Deep Dive"):
+            gr.Markdown(get_md_content("lora_qlora_deep_dive.md"))
+
+        with gr.TabItem("Hands-On: Model Fine-Tuning"):
+            with gr.Row():
+                start_button = gr.Button("Start fine-tuning", variant="primary")
+
+            log_output = gr.Textbox(
+                label="Training log",
+                interactive=False,
+                lines=20,
+                show_copy_button=True
+            )
+            start_button.click(fn=run_script, outputs=log_output)
+
+if __name__ == "__main__":
+    demo.launch()
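For context on why the rewritten run_script is a generator: Gradio treats a generator event handler as a stream and pushes each yielded value into the bound output component, which is what lets the Textbox fill with the training log live. A minimal, self-contained sketch of that pattern (the ticker/box names are illustrative, not part of this commit):

import time
import gradio as gr

def ticker():
    log = ""
    for i in range(3):
        log += f"step {i}\n"
        yield log  # each yield replaces the Textbox contents with the full log so far
        time.sleep(0.5)

with gr.Blocks() as demo:
    box = gr.Textbox(label="log", lines=5)
    gr.Button("run").click(fn=ticker, outputs=box)

if __name__ == "__main__":
    demo.launch()

Note that run_script yields the accumulated output rather than single lines, so each update repaints the complete log instead of overwriting it with the latest line.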
fine_tune.py CHANGED
@@ -4,20 +4,15 @@ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
 from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TrainingArguments
 from trl import SFTTrainer
 
-# 1. Load the model and tokenizer
-model_name = "NousResearch/Llama-2-7b-chat-hf"
-
-# BitsAndBytesConfig for QLoRA
-bnb_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_compute_dtype=torch.float16,
-)
+# 1. Load the model and tokenizer (CPU-optimized version)
+# Use a smaller model to fit the CPU environment
+model_name = "microsoft/DialoGPT-small"  # smaller model, suitable for CPU training
 
+# No quantization config is needed on CPU
 model = AutoModelForCausalLM.from_pretrained(
     model_name,
-    # quantization_config=bnb_config, # Deactivated for CPU
-    # device_map="auto" # Deactivated for CPU
+    torch_dtype=torch.float32,  # use float32 on CPU
+    low_cpu_mem_usage=True,  # reduce CPU memory usage
 )
 model.config.use_cache = False
 
@@ -34,31 +29,34 @@ def formatting_prompts_func(example):
 
 dataset = load_dataset("json", data_files="data.json", split="train")
 
-# 3. Configure the LoRA parameters
+# 3. Configure the LoRA parameters (adapted for DialoGPT)
 lora_config = LoraConfig(
     r=8,  # Rank
     lora_alpha=32,
     lora_dropout=0.1,
     bias="none",
     task_type="CAUSAL_LM",
-    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Llama-2 specific modules
+    target_modules=["c_attn", "c_proj"],  # attention modules of the DialoGPT/GPT-2 architecture
 )
 
-# 4. Create the PEFT model
-model = prepare_model_for_kbit_training(model)
+# 4. Create the PEFT model (CPU version)
+# No k-bit quantization preparation is needed on CPU
 model = get_peft_model(model, lora_config)
 
-# 5. Configure the training arguments
-output_dir = "./llama-2-7b-chat-json"
+# 5. Configure the training arguments (CPU-optimized)
+output_dir = "./dialogpt-small-lora"
 training_args = TrainingArguments(
     output_dir=output_dir,
-    per_device_train_batch_size=4,
-    gradient_accumulation_steps=4,
-    learning_rate=2e-4,
-    logging_steps=10,
-    max_steps=100,  # for demo
-    save_strategy="epoch",
-    # num_train_epochs=1, # use max_steps for demo
+    per_device_train_batch_size=1,  # smaller batch size for the CPU environment
+    gradient_accumulation_steps=8,  # more gradient accumulation to compensate for the small batch
+    learning_rate=5e-4,  # slightly higher learning rate
+    logging_steps=5,
+    max_steps=50,  # fewer training steps for the demo
+    save_strategy="steps",
+    save_steps=25,
+    dataloader_num_workers=0,  # set to 0 on CPU
+    fp16=False,  # fp16 is not supported on CPU
+    report_to=None,  # disable wandb and other reporting
 )
 
 # 6. Create the Trainer and start training
@@ -74,6 +72,6 @@
 trainer.train()
 
 # 7. Save the model
-print("Saving LoRA adapter...")
+print("Saving DialoGPT LoRA adapter...")
 trainer.save_model(output_dir)
-print(f"LoRA adapter saved to {output_dir}")
+print(f"DialoGPT LoRA adapter saved to {output_dir}")