Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Gradio Application for Stable Diffusion | |
| Author: Shilpaj Bhalerao | |
| Date: Feb 26, 2025 | |
| """ | |
| import os | |
| import torch | |
| import gradio as gr | |
| import spaces | |
| from tqdm.auto import tqdm | |
| import numpy as np | |
| from PIL import Image | |
| from utils import ( | |
| load_models, clear_gpu_memory, set_timesteps, latents_to_pil, | |
| vignette_loss, get_concept_embedding, load_concept_library, image_grid | |
| ) | |
| from diffusers import StableDiffusionPipeline | |
# Set device: prefer CUDA, then Apple MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if device == "mps":
    # Let PyTorch fall back to CPU kernels for ops not yet implemented on MPS.
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
def load_model():
    """
    Instantiate the Stable Diffusion v1.5 pipeline.

    Loads the model in fp16 with the safety checker disabled.

    Returns:
        StableDiffusionPipeline: A freshly constructed pipeline (not yet
        moved to any device).
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        safety_checker=None,
    )
    return pipeline
def get_pipeline():
    """
    Return the process-wide Stable Diffusion pipeline, creating it on first use.

    The previous implementation rebuilt the pipeline from ``from_pretrained``
    on every call — and this function is invoked many times per generated
    image — so the model was reloaded over and over. The pipeline is now
    created once and memoized on the function object. It is also moved to the
    module-level ``device`` selection instead of a hard-coded ``"cuda"``,
    which crashed on CUDA-less hosts.

    Returns:
        StableDiffusionPipeline: The cached pipeline on ``device``.
    """
    pipe = getattr(get_pipeline, "_cached_pipe", None)
    if pipe is None:
        pipe = load_model().to(device)
        get_pipeline._cached_pipe = pipe
    return pipe
# Load concept library: pre-trained embeddings plus their trigger tokens.
# NOTE(review): this runs at import time and pulls the full pipeline onto the
# accelerator before the UI starts — confirm this is intended for the
# deployment target (e.g. HF Spaces startup limits).
concept_embeds, concept_tokens = load_concept_library(get_pipeline())
# Define art style concepts: text descriptions used to derive style
# embeddings on the fly (see get_concept_embedding in generate_latents).
art_concepts = {
    "sketch_painting": "a sketch painting, pencil drawing, hand-drawn illustration",
    "oil_painting": "an oil painting, textured canvas, painterly technique",
    "watercolor": "a watercolor painting, fluid, soft edges",
    "digital_art": "digital art, computer generated, precise details",
    "comic_book": "comic book style, ink outlines, cel shading"
}
def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
                     vignette_loss_scale, concept_style=None, concept_strength=0.5,
                     height=512, width=512):
    """
    Run the diffusion denoising loop and return the final latents.

    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss (0 disables the
            extra guidance pass)
        concept_style (str, optional): Style concept to use; either a token
            from the pre-trained concept library or a key of ``art_concepts``
        concept_strength (float): Strength of concept influence (0.0-1.0)
        height (int): Image height (must be divisible by 8)
        width (int): Image width (must be divisible by 8)
    Returns:
        torch.Tensor: Generated latents of shape (1, C, height // 8, width // 8)
    """
    # Resolve the pipeline ONCE. The original body called get_pipeline() at
    # every use site (~15 times in the hot path), which re-instantiated the
    # whole model each time with the non-cached implementation.
    pipe = get_pipeline()
    batch_size = 1

    # Seed the global RNG; the returned generator drives the initial noise.
    generator = torch.manual_seed(seed)

    # Clear GPU memory before a fresh generation.
    clear_gpu_memory()

    # Get concept embedding if specified.
    concept_embedding = None
    if concept_style:
        if concept_style in concept_tokens:
            # Use pre-trained concept embedding from the library.
            concept_embedding = concept_embeds[concept_style].unsqueeze(0).to(device)
        elif concept_style in art_concepts:
            # Generate concept embedding from the style's text description.
            concept_text = art_concepts[concept_style]
            concept_embedding = get_concept_embedding(concept_text, pipe.tokenizer, pipe.text_encoder, device)

    # Encode the prompt.
    text_input = pipe.tokenizer([prompt], padding="max_length", max_length=pipe.tokenizer.model_max_length,
                                truncation=True, return_tensors="pt")
    with torch.inference_mode():
        text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]

    # Apply concept embedding influence if provided.
    if concept_embedding is not None and concept_strength > 0:
        # Align ranks: the concept embedding may lack the batch dimension.
        if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
            concept_embedding = concept_embedding.unsqueeze(0)
        # Weighted blend between the prompt embedding and the concept.
        if text_embeddings.shape == concept_embedding.shape:
            text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding

    # Unconditional embedding for classifier-free guidance.
    max_length = text_input.input_ids.shape[-1]
    uncond_input = pipe.tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.inference_mode():
        uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Prep scheduler.
    set_timesteps(pipe.scheduler, num_inference_steps)

    # Prep latents (use config accessor; bare unet.in_channels is deprecated).
    latents = torch.randn(
        (batch_size, pipe.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(device)
    latents = latents * pipe.scheduler.init_noise_sigma

    # Denoising loop.
    for i, t in tqdm(enumerate(pipe.scheduler.timesteps), total=len(pipe.scheduler.timesteps)):
        # Duplicate latents so one UNet pass serves uncond + cond branches.
        latent_model_input = torch.cat([latents] * 2)
        # NOTE(review): assumes a sigma-based scheduler (scheduler.sigmas
        # exists) — confirm utils.set_timesteps installs one (e.g. LMS).
        sigma = pipe.scheduler.sigmas[i]
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual.
        with torch.inference_mode():
            noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Apply additional guidance with vignette loss every 5th step.
        if vignette_loss_scale > 0 and i % 5 == 0:
            # Require grad on the latents for the guidance gradient.
            latents = latents.detach().requires_grad_()
            # Predicted clean latents (x0) under the current sigma.
            latents_x0 = latents - sigma * noise_pred
            # Decode to image space; scale 1/0.18215 is the SD VAE latent
            # scaling factor; result is mapped to the (0, 1) range.
            denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5
            # Calculate loss and its gradient w.r.t. the latents.
            loss = vignette_loss(denoised_images) * vignette_loss_scale
            cond_grad = torch.autograd.grad(loss, latents)[0]
            # Nudge the latents along the negative gradient, scaled by sigma^2.
            latents = latents.detach() - cond_grad * sigma**2

        # Step with scheduler.
        latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
    return latents
def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                   vignette_loss_scale=0.0, concept_style="none", concept_strength=0.5,
                   height=512, width=512):
    """
    Generate a single image with Stable Diffusion.

    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_style (str): Style concept to use ("none" disables styling)
        concept_strength (float): Strength of concept influence (0.0-1.0)
        height (int): Image height
        width (int): Image width
    Returns:
        PIL.Image: Generated image
    """
    # The UI sends the literal string "none" for "no style"; map it to None
    # so generate_latents skips concept blending.
    style = None if concept_style == "none" else concept_style

    latents = generate_latents(
        prompt=prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        vignette_loss_scale=vignette_loss_scale,
        concept_style=style,
        concept_strength=concept_strength,
        height=height,
        width=width,
    )

    # Decode the final latents and return the single generated image.
    return latents_to_pil(latents, get_pipeline().vae)[0]
def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                        vignette_loss_scale=0.0, concept_strength=0.5):
    """
    Generate a 1-row grid of images, one per art-style concept.

    Args:
        prompt (str): Text prompt
        seed (int): Base random seed (offset per style for variety)
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_strength (float): Strength of concept influence (0.0-1.0)
    Returns:
        PIL.Image: Grid of generated images, labeled by style
    """
    images = []
    labels = []

    # One image per style in art_concepts (dict preserves insertion order).
    for offset, style in enumerate(art_concepts):
        latents = generate_latents(
            prompt=prompt,
            seed=seed + offset,  # different seed per style for variety
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            vignette_loss_scale=vignette_loss_scale,
            concept_style=style,
            concept_strength=concept_strength,
        )
        # Decode latents and keep the single image plus its style label.
        images.append(latents_to_pil(latents, get_pipeline().vae)[0])
        labels.append(style)

    # Lay the images out as a single labeled row.
    return image_grid(images, 1, len(labels), labels)
# Define Gradio interface
def create_demo():
    """
    Build the Gradio Blocks UI.

    Two tabs: single-image generation (with optional style concept) and a
    style-comparison grid across all art concepts.

    Returns:
        gr.Blocks: The assembled (un-launched) Gradio app.
    """
    with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
        gr.Markdown("# Guided Stable Diffusion with Styles")
        with gr.Tab("Single Image Generation"):
            with gr.Row():
                with gr.Column():
                    # Generation controls.
                    prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
                    vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
                    # Combine SD concept library tokens and art concept descriptions
                    all_styles = ["none"] + concept_tokens + list(art_concepts.keys())
                    concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
                    concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
                    generate_btn = gr.Button("Generate Image")
                with gr.Column():
                    output_image = gr.Image(label="Generated Image", type="pil")
        with gr.Tab("Style Grid"):
            with gr.Row():
                with gr.Column():
                    # Grid controls mirror the single-image tab (minus style,
                    # since the grid iterates over all styles).
                    grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
                    grid_seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Base Seed", value=42)
                    grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
                    grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
                    grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
                    grid_generate_btn = gr.Button("Generate Style Grid")
                with gr.Column():
                    output_grid = gr.Image(label="Style Grid", type="pil")
        # Set up event handlers
        generate_btn.click(
            generate_image,
            inputs=[prompt, seed, num_inference_steps, guidance_scale,
                    vignette_loss_scale, concept_style, concept_strength],
            outputs=output_image
        )
        grid_generate_btn.click(
            generate_style_grid,
            inputs=[grid_prompt, grid_seed, grid_num_inference_steps,
                    grid_guidance_scale, grid_vignette_loss_scale, grid_concept_strength],
            outputs=output_grid
        )
    return demo
# Launch the app
if __name__ == "__main__":
    demo = create_demo()
    # `cache_examples` is an option of gr.Interface / gr.Examples, not of
    # Blocks.launch(); passing it here raises a TypeError at startup.
    demo.launch(debug=False, show_error=True, server_name="0.0.0.0", server_port=7860)