Spaces:
Runtime error
Runtime error
# train_lora.py – QLoRA + DeepSpeed DreamBooth Fine-Tuning (Stable Diffusion)
#
# Flow: read the training-image directory from the command line, build a LoRA
# adapter config, load Stable Diffusion 1.5 with 4-bit NF4 quantization,
# configure a DreamBooth LoRA trainer, and run it under an Accelerate
# accelerator that is backed by a DeepSpeed config file.
import os
import argparse
import torch
from diffusers import StableDiffusionPipeline, DDPMScheduler
# NOTE(review): `DreamBoothLoraTrainer` does not appear in the documented
# diffusers API (upstream provides DreamBooth LoRA training as the example
# script examples/dreambooth/train_dreambooth_lora.py). If this import raises
# ImportError, the trainer has to come from somewhere else -- confirm.
from diffusers import DreamBoothLoraTrainer
from diffusers.quantizers import PipelineQuantizationConfig
from peft import LoraConfig
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

parser = argparse.ArgumentParser()
parser.add_argument("--data", default="./nyc_ads_dataset")  # directory holding the training images
args = parser.parse_args()

# LoRA configuration (intended to be QLoRA-compatible).
# NOTE(review): "q_proj"/"v_proj" are the attention projection names used by
# CLIP-style text encoders; the Stable Diffusion UNet names its projections
# "to_q"/"to_v". Confirm which sub-model the trainer applies this config to.
lora_cfg = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
)

# Load SD-1.5 with 4-bit quantization.
# Pipelines take quantization settings as a PipelineQuantizationConfig; a bare
# `load_in_4bit=True` keyword or a plain dict is not the pipeline-level
# interface, so the bitsandbytes options are carried in `quant_kwargs`.
quant_cfg = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={
        "load_in_4bit": True,
        "bnb_4bit_compute_dtype": torch.float16,
        "bnb_4bit_use_double_quant": True,
        "bnb_4bit_quant_type": "nf4",
    },
)
# NOTE(review): `pipe` is not passed to the trainer below -- confirm whether
# the trainer is supposed to receive it or loads its own copy of the model.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
    quantization_config=quant_cfg,
)

# DreamBooth LoRA trainer
trainer = DreamBoothLoraTrainer(
    instance_data_root=args.data,
    instance_prompt="a photo of an urbanad nyc",
    lora_config=lora_cfg,
    output_dir="./nyc-ad-model",
    max_train_steps=400,
    train_batch_size=1,
    gradient_checkpointing=True,
)

# DeepSpeed ZeRO-3 acceleration / memory sharding.
# Accelerator has no `deepspeed_config` keyword; a DeepSpeed JSON config is
# supplied through a DeepSpeedPlugin. The config file must already exist at
# this path before the script is started.
ds_plugin = DeepSpeedPlugin(hf_ds_config="./ds_config_zero3.json")
accelerator = Accelerator(
    mixed_precision="fp16",
    deepspeed_plugin=ds_plugin,
)

# Start training
trainer.train(accelerator)