#!/usr/bin/env python3
"""
Gradio Application for Stable Diffusion
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import os

import torch
import gradio as gr
import spaces
from tqdm.auto import tqdm
import numpy as np
from PIL import Image
from diffusers import StableDiffusionPipeline

from utils import (
    load_models,
    clear_gpu_memory,
    set_timesteps,
    latents_to_pil,
    vignette_loss,
    get_concept_embedding,
    load_concept_library,
    image_grid
)

# Pick the best available device; MPS needs the CPU fallback enabled for
# operators PyTorch has not yet implemented on Metal.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if device == "mps":
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

# Module-level pipeline cache. NOTE(review): the original decorated
# get_pipeline with ``@gr.Cache()``, which is not a Gradio API and raises
# AttributeError at import time; a lazily-initialized module-level singleton
# is the standard replacement and preserves the "load once" intent.
_PIPELINE = None


@spaces.GPU
def load_model():
    """Load Stable Diffusion v1.5 in fp16 with the safety checker disabled.

    Returns:
        StableDiffusionPipeline: Freshly loaded pipeline (still on CPU).
    """
    return StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16,
        safety_checker=None
    )


@spaces.GPU
def get_pipeline():
    """Return the shared pipeline, loading it onto CUDA on first call.

    Returns:
        StableDiffusionPipeline: Cached pipeline resident on the GPU.
    """
    global _PIPELINE
    if _PIPELINE is None:
        _PIPELINE = load_model().to("cuda")
    return _PIPELINE


# Load concept library (pre-trained textual-inversion style embeddings).
# NOTE(review): this triggers the full model download/load at import time —
# intentional for HF Spaces startup, but worth confirming.
concept_embeds, concept_tokens = load_concept_library(get_pipeline())

# Text descriptions used to synthesize style embeddings on the fly.
art_concepts = {
    "sketch_painting": "a sketch painting, pencil drawing, hand-drawn illustration",
    "oil_painting": "an oil painting, textured canvas, painterly technique",
    "watercolor": "a watercolor painting, fluid, soft edges",
    "digital_art": "digital art, computer generated, precise details",
    "comic_book": "comic book style, ink outlines, cel shading"
}


def generate_latents(prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale,
                     concept_style=None, concept_strength=0.5, height=512, width=512):
    """
    Generate latents using the UNet model

    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_style (str, optional): Style concept to use
        concept_strength (float): Strength of concept influence (0.0-1.0)
        height (int): Image height
        width (int): Image width

    Returns:
        torch.Tensor: Generated latents
    """
    # Set the seed
    generator = torch.manual_seed(seed)
    batch_size = 1

    # Clear GPU memory
    clear_gpu_memory()

    # Fetch the shared pipeline once instead of calling get_pipeline()
    # on every access (the original re-called it ~15 times per generation).
    pipe = get_pipeline()

    # Get concept embedding if specified
    concept_embedding = None
    if concept_style:
        if concept_style in concept_tokens:
            # Use pre-trained concept embedding
            concept_embedding = concept_embeds[concept_style].unsqueeze(0).to(device)
        elif concept_style in art_concepts:
            # Generate concept embedding from text description
            concept_text = art_concepts[concept_style]
            concept_embedding = get_concept_embedding(concept_text, pipe.tokenizer, pipe.text_encoder, device)

    # Prep text
    text_input = pipe.tokenizer([prompt], padding="max_length",
                                max_length=pipe.tokenizer.model_max_length,
                                truncation=True, return_tensors="pt")
    with torch.inference_mode():
        text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]

    # Apply concept embedding influence if provided
    if concept_embedding is not None and concept_strength > 0:
        # Fix the dimension mismatch by adding a batch dimension to concept_embedding if needed
        if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
            concept_embedding = concept_embedding.unsqueeze(0)
        # Create weighted blend between original text embedding and concept
        if text_embeddings.shape == concept_embedding.shape:
            text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding

    # Unconditional embedding for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.inference_mode():
        uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Prep Scheduler
    set_timesteps(pipe.scheduler, num_inference_steps)

    # Prep latents (VAE downsamples by 8x in each spatial dimension).
    # NOTE(review): in_channels is read via .config — direct attribute access
    # on the UNet is deprecated in diffusers.
    latents = torch.randn(
        (batch_size, pipe.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(device)
    latents = latents * pipe.scheduler.init_noise_sigma

    # Loop through diffusion process
    for i, t in tqdm(enumerate(pipe.scheduler.timesteps), total=len(pipe.scheduler.timesteps)):
        # Expand latents for classifier-free guidance
        latent_model_input = torch.cat([latents] * 2)
        sigma = pipe.scheduler.sigmas[i]
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.inference_mode():
            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Apply additional guidance with vignette loss every 5th step
        if vignette_loss_scale > 0 and i % 5 == 0:
            # Requires grad on the latents
            latents = latents.detach().requires_grad_()

            # Get the predicted x0
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space; 0.18215 is the SD VAE scaling factor
            denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

            # Calculate loss
            loss = vignette_loss(denoised_images) * vignette_loss_scale

            # Get gradient
            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Modify the latents based on this gradient
            latents = latents.detach() - cond_grad * sigma**2

        # Step with scheduler
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample

    return latents


@spaces.GPU
def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                   vignette_loss_scale=0.0, concept_style="none", concept_strength=0.5,
                   height=512, width=512):
    """
    Generate an image using Stable Diffusion

    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_style (str): Style concept to use
        concept_strength (float): Strength of concept influence (0.0-1.0)
        height (int): Image height
        width (int): Image width

    Returns:
        PIL.Image: Generated image
    """
    # Handle "none" concept style
    if concept_style == "none":
        concept_style = None

    # Generate latents
    latents = generate_latents(
        prompt=prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        vignette_loss_scale=vignette_loss_scale,
        concept_style=concept_style,
        concept_strength=concept_strength,
        height=height,
        width=width
    )

    # Convert latents to image
    images = latents_to_pil(latents, get_pipeline().vae)
    return images[0]


@spaces.GPU
def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                        vignette_loss_scale=0.0, concept_strength=0.5):
    """
    Generate a grid of images with different style concepts

    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_strength (float): Strength of concept influence (0.0-1.0)

    Returns:
        PIL.Image: Grid of generated images
    """
    # List of styles to use
    styles = list(art_concepts.keys())

    # Generate images for each style
    images = []
    labels = []
    for i, style in enumerate(styles):
        # Generate image with this style
        latents = generate_latents(
            prompt=prompt,
            seed=seed + i,  # Use different seeds for variety
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            vignette_loss_scale=vignette_loss_scale,
            concept_style=style,
            concept_strength=concept_strength
        )

        # Convert latents to image
        style_images = latents_to_pil(latents, get_pipeline().vae)
        images.append(style_images[0])
        labels.append(style)

    # Create grid
    grid = image_grid(images, 1, len(styles), labels)
    return grid


def create_demo():
    """Build the Gradio Blocks UI and wire the event handlers.

    NOTE(review): the original decorated this with
    ``@spaces.GPU(enable_queue=False)``; UI construction needs no GPU and
    ``enable_queue`` is not a ``spaces.GPU`` parameter, so the decorator
    was removed.

    Returns:
        gr.Blocks: The assembled demo.
    """
    with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
        gr.Markdown("# Guided Stable Diffusion with Styles")

        with gr.Tab("Single Image Generation"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
                    vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)

                    # Combine SD concept library tokens and art concept descriptions
                    all_styles = ["none"] + concept_tokens + list(art_concepts.keys())
                    concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
                    concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
                    generate_btn = gr.Button("Generate Image")
                with gr.Column():
                    output_image = gr.Image(label="Generated Image", type="pil")

        with gr.Tab("Style Grid"):
            with gr.Row():
                with gr.Column():
                    grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
                    grid_seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Base Seed", value=42)
                    grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
                    grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
                    grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
                    grid_generate_btn = gr.Button("Generate Style Grid")
                with gr.Column():
                    output_grid = gr.Image(label="Style Grid", type="pil")

        # Set up event handlers
        generate_btn.click(
            generate_image,
            inputs=[prompt, seed, num_inference_steps, guidance_scale,
                    vignette_loss_scale, concept_style, concept_strength],
            outputs=output_image
        )
        grid_generate_btn.click(
            generate_style_grid,
            inputs=[grid_prompt, grid_seed, grid_num_inference_steps, grid_guidance_scale,
                    grid_vignette_loss_scale, grid_concept_strength],
            outputs=output_grid
        )

    return demo


# Launch the app
if __name__ == "__main__":
    demo = create_demo()
    # NOTE(review): cache_examples is not a launch() parameter (it belongs to
    # gr.Interface / gr.Examples) and raises TypeError on current Gradio, so
    # it was dropped.
    demo.launch(debug=False, show_error=True, server_name="0.0.0.0", server_port=7860)