wilsonHuggingFace commited on
Commit
f42fbdd
·
verified ·
1 Parent(s): 48ed12e

Create SDXL_Lighting.py

Browse files
Files changed (1) hide show
  1. SDXL_Lighting.py +30 -0
SDXL_Lighting.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
3
+ from huggingface_hub import hf_hub_download
4
+ from safetensors.torch import load_file
5
+
6
+ base = "stabilityai/stable-diffusion-xl-base-1.0"
7
+ repo = "ByteDance/SDXL-Lightning"
8
+ ckpt = "sdxl_lightning_4step_unet.safetensors" # Use the correct ckpt for your step setting!
9
+
10
+ ## download
11
+ # git clone https://hf-mirror.com/stabilityai/stable-diffusion-xl-base-1.0
12
+ # wget -c hf-mirror.com/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/unet/diffusion_pytorch_model.fp16.safetensors
13
+ # wget -c https://hf-mirror.com/ByteDance/SDXL-Lightning/sdxl_lightning_4step_unet.safetensors
14
+ # base = xxx/stable-diffusion-xl-base-1.0
15
+ # ckpt = SDXL-Lightning/sdxl_lightning_4step_unet.safetensors
16
+
17
+ # Load model.
18
+ # unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
19
+ # unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
20
+ # pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
21
+
22
+ unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
23
+ unet.load_state_dict(load_file(ckpt), device="cuda"))
24
+ pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
25
+
26
+ # Ensure sampler uses "trailing" timesteps.
27
+ pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
28
+
29
+ # Ensure using the same inference steps as the loaded model and CFG set to 0.
30
+ pipe("A girl smiling", num_inference_steps=4, guidance_scale=0).images[0].save("output.png")