import os
import tempfile
import time
from pathlib import Path

import gradio as gr
import spaces
import torch
from diffusers import (
    StableDiffusionXLPipeline,
    EulerDiscreteScheduler,
    UNet2DConditionModel,
    AutoencoderTiny,
)
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
|
|
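# Model sources: the SDXL base weights, ByteDance's distilled 2-step Lightning UNet, and the TAESD tiny VAE.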
BASE = "stabilityai/stable-diffusion-xl-base-1.0"
REPO = "ByteDance/SDXL-Lightning"
CHECKPOINT = "sdxl_lightning_2step_unet.safetensors"
TAESD_MODEL = "madebyollin/taesdxl"
|
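# Feature flags, set via environment variables ("1" enables).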
SFAST_COMPILE = os.environ.get("SFAST_COMPILE", "0") == "1"
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"
USE_TAESD = os.environ.get("USE_TAESD", "0") == "1"
|
|
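# Prefer CUDA when available; half precision only makes sense on GPU.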
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch_device = device
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
|
|
print(f"SAFETY_CHECKER: {SAFETY_CHECKER}")
print(f"SFAST_COMPILE: {SFAST_COMPILE}")
print(f"USE_TAESD: {USE_TAESD}")
print(f"device: {device}")
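
# Build the base SDXL pipeline with the distilled 2-step Lightning UNet swapped in.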
unet = UNet2DConditionModel.from_config(BASE, subfolder="unet").to(device, torch_dtype)
unet.load_state_dict(load_file(hf_hub_download(REPO, CHECKPOINT), device=str(device)))
pipe = StableDiffusionXLPipeline.from_pretrained(
    BASE, unet=unet, torch_dtype=torch_dtype, variant="fp16"
).to(device)
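
# Optionally replace the VAE with TAESD for faster decoding at slightly lower fidelity.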
if USE_TAESD:
    pipe.vae = AutoencoderTiny.from_pretrained(
        TAESD_MODEL, torch_dtype=torch_dtype, use_safetensors=True
    ).to(device)
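
# Lightning checkpoints require "trailing" timestep spacing on the sampler.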
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing"
)
pipe.set_progress_bar_config(disable=True)
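
# Optional CLIP-based NSFW filter, using this Space's bundled safety_checker module.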
if SAFETY_CHECKER:
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker"
    ).to(device)
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )

    def check_nsfw_images(
        images: list[Image.Image],
    ) -> tuple[list[Image.Image], list[bool]]:
        safety_checker_input = feature_extractor(images, return_tensors="pt").to(device)
        has_nsfw_concepts = safety_checker(
            images=[images],
            clip_input=safety_checker_input.pixel_values.to(torch_device),
        )
        return images, has_nsfw_concepts
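
# Optionally compile the pipeline with stable-fast (xformers, Triton, and CUDA graphs when available).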
if SFAST_COMPILE:
    from sfast.compilers.diffusion_pipeline_compiler import compile, CompilationConfig

    config = CompilationConfig.Default()
    try:
        import xformers

        config.enable_xformers = True
    except ImportError:
        print("xformers not installed, skipping")
    try:
        import triton

        config.enable_triton = True
    except ImportError:
        print("Triton not installed, skipping")
    config.enable_cuda_graph = True

    pipe = compile(pipe, config)
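
# @spaces.GPU requests a GPU (e.g. on ZeroGPU Spaces) for the duration of each call.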
@spaces.GPU
def predict(prompt, seed=1231231):
    generator = torch.manual_seed(seed)
    last_time = time.time()
    results = pipe(
        prompt=prompt,
        generator=generator,
        num_inference_steps=2,
        guidance_scale=0.0,
        output_type="pil",
    )
    print(f"Pipe took {time.time() - last_time} seconds")
    if SAFETY_CHECKER:
        images, has_nsfw_concepts = check_nsfw_images(results.images)
        if any(has_nsfw_concepts):
            gr.Warning("NSFW content detected.")
            return Image.new("RGB", (512, 512))
    image = results.images[0]
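    # Encode to JPEG in a temp file; gr.Image(type="filepath") serves the returned path.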
    with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as tmpfile:
        image.save(tmpfile, "JPEG", quality=80, optimize=True, progressive=True)
    return Path(tmpfile.name)
|
|
css = """
#container{
    margin: 0 auto;
    max-width: 40rem;
}
#intro{
    max-width: 100%;
    margin: 0 auto;
}
"""
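
# Gradio UI: prompt box with live generation, seed slider, and a "Run with diffusers" guide.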
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="container"):
        gr.Markdown(
            """
# SDXL-Lightning: Text-to-Image in 2 Steps
**Model**: https://huggingface.co/ByteDance/SDXL-Lightning
""",
            elem_id="intro",
        )
        with gr.Row():
            prompt = gr.Textbox(
                placeholder="Insert your prompt here", scale=5, container=False
            )
            generate_bt = gr.Button("Generate", scale=1)

        image = gr.Image(type="filepath")
        with gr.Accordion("Advanced options", open=False):
            seed = gr.Slider(
                randomize=True, minimum=0, maximum=12013012031030, label="Seed", step=1
            )
        with gr.Accordion("Run with diffusers"):
            gr.Markdown(
                """## Running SDXL-Lightning with `diffusers`
```py
import torch
from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel, EulerDiscreteScheduler
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

base = "stabilityai/stable-diffusion-xl-base-1.0"
repo = "ByteDance/SDXL-Lightning"
ckpt = "sdxl_lightning_2step_unet.safetensors"  # Use the correct ckpt for your step setting!

# Load model.
unet = UNet2DConditionModel.from_config(base, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device="cuda"))
pipe = StableDiffusionXLPipeline.from_pretrained(base, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")

# Ensure sampler uses "trailing" timesteps.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")

# Ensure using the same inference steps as the loaded model and CFG set to 0.
pipe("A girl smiling", num_inference_steps=2, guidance_scale=0).images[0].save("output.png")
```
"""
            )
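
    # Wire events: the Generate button, live prompt typing, and seed changes all call predict.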
    inputs = [prompt, seed]
    outputs = [image]
    generate_bt.click(
        fn=predict, inputs=inputs, outputs=outputs, show_progress=False
    )
    prompt.input(
        fn=predict,
        inputs=inputs,
        outputs=outputs,
        trigger_mode="always_last",
        show_progress=False,
    )
    seed.change(fn=predict, inputs=inputs, outputs=outputs, show_progress=False)
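
# Queue requests so concurrent users are served in order, then start the app.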
demo.queue()
demo.launch()