Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| import spaces | |
| import os | |
| from PIL import Image, ImageFilter | |
| from typing import List, Tuple | |
# NSFW screening of generated images is enabled only when the environment
# variable SAFETY_CHECKER is exactly "1".
SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", "0") == "1"

# Constants
base = "stabilityai/stable-diffusion-xl-base-1.0"  # Hub id of the base SDXL pipeline
repo = "ByteDance/SDXL-Lightning"  # Hub repo that holds the Lightning UNet checkpoints

# UI label -> [UNet checkpoint filename inside ``repo``, number of inference steps]
checkpoints = {
    "1-Step": ["sdxl_lightning_1step_unet_x0.safetensors", 1],
    "2-Step": ["sdxl_lightning_2step_unet.safetensors", 2],
    "4-Step": ["sdxl_lightning_4step_unet.safetensors", 4],
    "8-Step": ["sdxl_lightning_8step_unet.safetensors", 8],
}

# UI label -> (width units, height units), expressed for landscape orientation.
# Insertion order is the order shown in the aspect-ratio dropdown.
aspect_ratios = {
    "21:9": (21, 9),
    "2:1": (2, 1),
    "16:9": (16, 9),
    "5:4": (5, 4),
    "4:3": (4, 3),
    "3:2": (3, 2),
    "1:1": (1, 1),
}


# Function to calculate resolution
def calculate_resolution(aspect_ratio, mode='landscape', total_pixels=1024*1024, divisibility=8):
    """Choose a ``(width, height)`` for *aspect_ratio* within a pixel budget.

    Args:
        aspect_ratio: Key of ``aspect_ratios`` (e.g. ``"16:9"``).
        mode: ``'portrait'`` swaps the orientation; any other value is treated
            as landscape.
        total_pixels: Upper bound on ``width * height``.
        divisibility: Both dimensions are rounded down to a multiple of this.

    Returns:
        ``(width, height)`` as ints, each a multiple of *divisibility*, with
        ``width * height <= total_pixels``.

    Raises:
        ValueError: If *aspect_ratio* is not a key of ``aspect_ratios``.
    """
    try:
        w_units, h_units = aspect_ratios[aspect_ratio]
    except KeyError:
        raise ValueError(f"Invalid aspect ratio: {aspect_ratio}") from None

    ratio = w_units / h_units
    if mode == 'portrait':
        ratio = 1 / ratio  # taller than wide: invert width/height

    def snap(value):
        # Round down to the nearest multiple of ``divisibility``.
        return value - value % divisibility

    # Height first: width = height * ratio and width * height = total_pixels.
    height = snap(int((total_pixels / ratio) ** 0.5))
    width = snap(int(height * ratio))

    # Defensive: shrink in whole steps until the area fits the budget.
    while width * height > total_pixels:
        height -= divisibility
        width = snap(int(height * ratio))
    return width, height
# Example prompts with ckpt, aspect, and mode. Each entry holds the values that
# set_example pushes into the prompt, negative prompt, steps, aspect and mode inputs.
examples = [
    dict(
        prompt="A futuristic cityscape at sunset",
        negative_prompt="Ugly",
        ckpt="4-Step",
        aspect="16:9",
        mode="landscape",
    ),
    dict(
        prompt="pair of shoes made of dried fruit skins, 3d render, bright colours, clean composition, beautiful artwork, logo",
        negative_prompt="Ugly",
        ckpt="2-Step",
        aspect="1:1",
        mode="portrait",
    ),
    dict(
        prompt="A portrait of a robot in the style of Renaissance art",
        negative_prompt="Ugly",
        ckpt="2-Step",
        aspect="1:1",
        mode="portrait",
    ),
    dict(
        prompt="full body of alien shaped like woman, big golden eyes, mars planet, photo, digital art, fantasy",
        negative_prompt="Ugly",
        ckpt="4-Step",
        aspect="1:1",
        mode="portrait",
    ),
    dict(
        prompt="A serene landscape with mountains and a river",
        negative_prompt="Ugly",
        ckpt="8-Step",
        aspect="3:2",
        mode="landscape",
    ),
    dict(
        prompt="post-apocalyptic wasteland, the most delicate beautiful flower with green leaves growing from dust and rubble, vibrant colours, cinematic",
        negative_prompt="Ugly",
        ckpt="8-Step",
        aspect="16:9",
        mode="landscape",
    ),
]


def set_example(selected_prompt):
    """Look up the example whose prompt equals *selected_prompt*.

    Returns:
        A 5-tuple ``(prompt, negative_prompt, ckpt, aspect, mode)`` taken from
        the first matching entry of ``examples``, or five ``None`` values when
        nothing matches.
    """
    fields = ("prompt", "negative_prompt", "ckpt", "aspect", "mode")
    match = next((ex for ex in examples if ex["prompt"] == selected_prompt), None)
    if match is None:
        return (None,) * len(fields)
    return tuple(match[field] for field in fields)
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the SDXL base pipeline onto ``device``. On CUDA the fp16 weight variant
# is fetched and run in half precision; on CPU the default weights/dtype are used.
pipe = StableDiffusionXLPipeline.from_pretrained(
    base,
    **({"torch_dtype": torch.float16, "variant": "fp16"} if device == "cuda" else {}),
).to(device)

# Optional NSFW screening, enabled through the SAFETY_CHECKER environment flag.
if SAFETY_CHECKER:
    # Imported lazily: only needed when screening is turned on.
    # ``safety_checker`` is presumably a module shipped alongside this app — confirm.
    from safety_checker import StableDiffusionSafetyChecker
    from transformers import CLIPFeatureExtractor

    safety_checker = StableDiffusionSafetyChecker.from_pretrained(
        "CompVis/stable-diffusion-safety-checker",
    ).to(device)
    # NOTE(review): transformers deprecates CLIPFeatureExtractor in favour of
    # CLIPImageProcessor; kept as-is, consider switching if the installed
    # transformers version provides it.
    feature_extractor = CLIPFeatureExtractor.from_pretrained(
        "openai/clip-vit-base-patch32",
    )
def check_nsfw_images(
    images: List[Image.Image],
) -> Tuple[List[Image.Image], List[bool]]:
    """Run the NSFW safety checker on each image.

    Relies on the module-level ``feature_extractor`` and ``safety_checker``,
    which are only created when ``SAFETY_CHECKER`` is enabled, so call this
    only under that flag.

    Args:
        images: PIL images to screen (e.g. ``pipe(...).images``).

    Returns:
        ``(images, has_nsfw_concepts)``: the input images, unchanged, and the
        per-image results from ``safety_checker`` concatenated into one flat
        list (empty when *images* is empty).
    """
    has_nsfw_concepts = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for image in images:
            # Send the pixels to the same device the safety checker was moved
            # to (``device``) rather than hard-coding CUDA, so CPU runs work.
            pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
            # NOTE(review): the local ``safety_checker`` module's forward() is not
            # visible here; this assumes it returns an iterable of per-image
            # NSFW flags for the ``images=`` it is given — confirm against
            # safety_checker.py.
            result = safety_checker(images=[image], clip_input=pixel_values)
            has_nsfw_concepts.extend(result)
    return images, has_nsfw_concepts
# State describing what has been applied to the shared global ``pipe``.
# NOTE(review): like the original handler, this mutates the global pipeline on
# every request, so it assumes requests are processed one at a time — confirm
# the Gradio queue concurrency limit stays at 1.
_base_scheduler_config = None  # pristine ``pipe.scheduler.config``, captured on first request
_loaded_checkpoint = None  # filename of the UNet weights currently loaded in ``pipe.unet``


def _configure_scheduler(num_inference_steps):
    """Install a fresh Euler scheduler with "trailing" timesteps on ``pipe``.

    The scheduler is always rebuilt from the pipeline's ORIGINAL config. If it
    were rebuilt from the current ``pipe.scheduler.config``, the
    ``prediction_type="sample"`` override from a 1-step request would leak
    into every later 2/4/8-step request.
    """
    global _base_scheduler_config
    if _base_scheduler_config is None:
        # Nothing else replaces ``pipe.scheduler``, so on the first call this
        # is still the config the pipeline was loaded with.
        _base_scheduler_config = pipe.scheduler.config
    overrides = {"timestep_spacing": "trailing"}
    if num_inference_steps == 1:
        # The 1-step checkpoint must be sampled with prediction_type="sample".
        overrides["prediction_type"] = "sample"
    pipe.scheduler = EulerDiscreteScheduler.from_config(_base_scheduler_config, **overrides)


def _load_unet_checkpoint(checkpoint):
    """Load the *checkpoint* file from ``repo`` into ``pipe.unet`` unless already loaded."""
    global _loaded_checkpoint
    if checkpoint == _loaded_checkpoint:
        return  # same weights as the previous request; skip re-reading them
    state_dict = load_file(hf_hub_download(repo, checkpoint), device=device)
    pipe.unet.load_state_dict(state_dict)
    # Only record the checkpoint once loading has succeeded.
    _loaded_checkpoint = checkpoint


# NOTE(review): ``spaces`` is imported at the top of the file but this handler
# is not decorated with ``@spaces.GPU``; if the Space runs on ZeroGPU hardware
# it likely needs to be — confirm the target hardware.
def generate_image(prompt, negative_prompt, ckpt, aspect_ratio, mode):
    """Generate a single image with SDXL-Lightning.

    Args:
        prompt: Text prompt.
        negative_prompt: Negative prompt. It is forwarded to the pipeline but
            has no effect while ``guidance_scale=0`` (see below).
        ckpt: Key of ``checkpoints`` selecting the UNet weights and step count.
        aspect_ratio: Key of ``aspect_ratios``.
        mode: ``'portrait'`` or ``'landscape'`` orientation.

    Returns:
        The first generated PIL image. When the safety checker is enabled and
        flags the output, a Gaussian-blurred copy is returned instead.

    Raises:
        KeyError: If *ckpt* is not a key of ``checkpoints``.
        ValueError: If *aspect_ratio* is not a key of ``aspect_ratios``.
    """
    width, height = calculate_resolution(aspect_ratio, mode)
    checkpoint, num_inference_steps = checkpoints[ckpt]

    _configure_scheduler(num_inference_steps)
    _load_unet_checkpoint(checkpoint)

    # In diffusers' SDXL pipeline, classifier-free guidance is only applied
    # when guidance_scale > 1, so with guidance_scale=0 the negative prompt
    # does not influence the output.
    results = pipe(
        prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0,
        width=width,
        height=height,
    )

    if not SAFETY_CHECKER:
        return results.images[0]

    images, has_nsfw_concepts = check_nsfw_images(results.images)
    if any(has_nsfw_concepts):
        gr.Warning("NSFW content detected.")
        # Blur radius 16; raise it to obscure flagged images more strongly.
        return images[0].filter(ImageFilter.GaussianBlur(16))
    return images[0]
# Gradio Interface
description = """
SDXL-Lightning ByteDance model demo. Link to model: https://huggingface.co/ByteDance/SDXL-Lightning
"""

with gr.Blocks(css="style.css") as demo:
    gr.HTML("<h1><center>Text-to-Image with SDXL-Lightning ⚡</center></h1>")
    gr.Markdown(description)
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(label='Enter your image prompt:', scale=8)
        with gr.Row():
            negative_prompt = gr.Textbox(label='Optional negative prompt:', scale=8)
        with gr.Row():
            # Choices come from the checkpoint / aspect-ratio tables so the UI
            # cannot drift out of sync with what generate_image accepts.
            ckpt = gr.Dropdown(label='Select inference steps', choices=list(checkpoints), value='4-Step', interactive=True)
            aspect = gr.Dropdown(label='Aspect Ratio', choices=list(aspect_ratios.keys()), value='1:1', interactive=True)
            mode = gr.Dropdown(label='Mode', choices=['landscape', 'portrait'], value='landscape')  # Mode as a dropdown
            submit = gr.Button(scale=1, variant='primary')
    img = gr.Image(label='SDXL-Lightning Generated Image')

    # Pressing Enter in the prompt box and clicking the button both generate.
    generate_inputs = [prompt, negative_prompt, ckpt, aspect, mode]
    prompt.submit(fn=generate_image, inputs=generate_inputs, outputs=img)
    submit.click(fn=generate_image, inputs=generate_inputs, outputs=img)

    # Dropdown for selecting examples: picking one fills in all generation inputs.
    example_dropdown = gr.Dropdown(label='Select an Example', choices=[e["prompt"] for e in examples])
    example_dropdown.change(fn=set_example, inputs=example_dropdown, outputs=[prompt, negative_prompt, ckpt, aspect, mode])

demo.queue().launch()