goodmodeler commited on
Commit
686c304
·
1 Parent(s): 696ae63

add train lora

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. train_lora.py +51 -0
README.md CHANGED
@@ -49,6 +49,7 @@ torch.cuda.reset_peak_memory_stats()
49
 
50
  7/12
51
  # 1 Fine‑tune image model LoRA+QLoRA
 
52
  python train_lora.py
53
 
54
  # 2 SFT 语言模型
 
49
 
50
  7/12
51
  # 1 Fine‑tune image model LoRA+QLoRA
52
+ accelerate launch --deepspeed_config_file=ds_config_zero3.json train_lora.py
53
  python train_lora.py
54
 
55
  # 2 SFT 语言模型
train_lora.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # train_lora.py – QLoRA + DeepSpeed DreamBooth Fine-Tuning (Stable Diffusion)
2
+
3
+ import os, argparse, torch
4
+ from diffusers import StableDiffusionPipeline, DDPMScheduler
5
+ from diffusers import DreamBoothLoraTrainer
6
+ from peft import LoraConfig
7
+ from accelerate import Accelerator
8
+
9
+ parser = argparse.ArgumentParser()
10
+ parser.add_argument("--data", default="./nyc_ads_dataset") # 你的训练图片目录
11
+ args = parser.parse_args()
12
+
13
+ # LoRA 配置(兼容 QLoRA)
14
+ lora_cfg = LoraConfig(
15
+ r=8,
16
+ lora_alpha=32,
17
+ lora_dropout=0.05,
18
+ target_modules=["q_proj", "v_proj"]
19
+ )
20
+
21
+ # 4-bit 量化加载 SD-1.5
22
+ pipe = StableDiffusionPipeline.from_pretrained(
23
+ "runwayml/stable-diffusion-v1-5",
24
+ torch_dtype=torch.float16,
25
+ load_in_4bit=True,
26
+ quantization_config={
27
+ "bnb_4bit_compute_dtype": torch.float16,
28
+ "bnb_4bit_use_double_quant": True,
29
+ "bnb_4bit_quant_type": "nf4"
30
+ },
31
+ )
32
+
33
+ # DreamBooth LoRA Trainer
34
+ trainer = DreamBoothLoraTrainer(
35
+ instance_data_root=args.data,
36
+ instance_prompt="a photo of an urbanad nyc",
37
+ lora_config=lora_cfg,
38
+ output_dir="./nyc-ad-model",
39
+ max_train_steps=400,
40
+ train_batch_size=1,
41
+ gradient_checkpointing=True,
42
+ )
43
+
44
+ # DeepSpeed ZeRO-3 加速 / 显存拆分
45
+ accelerator = Accelerator(
46
+ mixed_precision="fp16",
47
+ deepspeed_config="./ds_config_zero3.json" # 需提前放置
48
+ )
49
+
50
+ # 开始训练
51
+ trainer.train(accelerator)