Update app.py
Browse files
app.py
CHANGED
|
@@ -6,9 +6,9 @@ from safetensors.torch import load_file
|
|
| 6 |
|
| 7 |
# Define a function to generate the image
|
| 8 |
def generate_image(prompt, num_inference_steps):
|
| 9 |
-
base = "stable-diffusion"
|
| 10 |
-
repo = "
|
| 11 |
-
ckpt = "
|
| 12 |
|
| 13 |
# Load model.
|
| 14 |
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
|
|
|
|
| 6 |
|
| 7 |
# Define a function to generate the image
|
| 8 |
def generate_image(prompt, num_inference_steps):
|
| 9 |
+
base = "stabilityai/stable-diffusion-xl-base-1.0"
|
| 10 |
+
repo = "ByteDance/SDXL-Lightning"
|
| 11 |
+
ckpt = "sdxl_lightning_2step_unet.safetensors" # Use the correct ckpt for your step setting!
|
| 12 |
|
| 13 |
# Load model.
|
| 14 |
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
|