StableDiffusion / app.py
Shilpaj's picture
Upload 2 files
e7f5c3d verified
raw
history blame
12.8 kB
#!/usr/bin/env python3
"""
Gradio Application for Stable Diffusion
Author: Shilpaj Bhalerao
Date: Feb 26, 2025
"""
import functools
import os

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from tqdm.auto import tqdm

from utils import (
    load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
    vignette_loss, get_concept_embedding, load_concept_library, image_grid
)
# Pick the compute device: CUDA first, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
    # Ask PyTorch to fall back to CPU for ops that MPS does not implement.
    # NOTE(review): torch is already imported at this point - confirm the
    # variable is still honoured when it is set this late.
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
else:
    device = "cpu"
# Build the Stable Diffusion pipeline from the pretrained checkpoint
@spaces.GPU
def load_model():
    """
    Instantiate the Stable Diffusion v1.5 pipeline
    Returns:
        StableDiffusionPipeline: Pipeline with half-precision weights and
        the safety checker disabled
    """
    model_id = "runwayml/stable-diffusion-v1-5"
    load_options = {
        "torch_dtype": torch.float16,
        "safety_checker": None,
    }
    return StableDiffusionPipeline.from_pretrained(model_id, **load_options)
@spaces.GPU
@functools.lru_cache(maxsize=1)
def get_pipeline():
    """
    Return the shared Stable Diffusion pipeline, loading it on first use
    Returns:
        StableDiffusionPipeline: Pipeline moved to the module-level `device`
    """
    # Gradio has no `gr.Cache` decorator, so the original raised AttributeError
    # while the module was being imported. `lru_cache` gives the intended
    # "load once, reuse afterwards" behaviour for this zero-argument function.
    # NOTE(review): the cache lives in whichever process executes this body;
    # confirm that matches how `spaces.GPU` dispatches the call.
    pipe = load_model()
    # Use the detected device rather than a hard-coded "cuda" so the pipeline
    # sits on the same device the rest of the file sends its tensors to.
    return pipe.to(device)
# Load concept library
# This runs at import time: importing the module calls get_pipeline(), which in
# turn calls load_model(), before any request is served.
# load_concept_library() returns a pair:
#   concept_embeds - indexed by style name in generate_latents()
#   concept_tokens - the available style names; generate_latents() tests
#                    membership in it and create_demo() concatenates it with
#                    lists to build the style dropdown
concept_embeds, concept_tokens = load_concept_library(get_pipeline())
# Text descriptions for the built-in art styles, keyed by style name.
# The key order is the column order used by the style grid.
art_concepts = dict(
    sketch_painting="a sketch painting, pencil drawing, hand-drawn illustration",
    oil_painting="an oil painting, textured canvas, painterly technique",
    watercolor="a watercolor painting, fluid, soft edges",
    digital_art="digital art, computer generated, precise details",
    comic_book="comic book style, ink outlines, cel shading",
)
def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
                     vignette_loss_scale, concept_style=None, concept_strength=0.5,
                     height=512, width=512):
    """
    Generate latents using the UNet model
    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss (0 disables it)
        concept_style (str, optional): Style concept to use; either a name from
            the loaded concept library or a key of `art_concepts`
        concept_strength (float): Strength of concept influence (0.0-1.0)
        height (int): Image height
        width (int): Image width
    Returns:
        torch.Tensor: Generated latents
    """
    # Set the seed
    generator = torch.manual_seed(seed)
    batch_size = 1

    # Clear GPU memory
    clear_gpu_memory()

    # Resolve the pipeline once. The original called get_pipeline() for every
    # attribute access, including several times per denoising step.
    pipe = get_pipeline()
    tokenizer = pipe.tokenizer
    text_encoder = pipe.text_encoder
    unet = pipe.unet
    vae = pipe.vae
    scheduler = pipe.scheduler

    # Get concept embedding if specified
    concept_embedding = None
    if concept_style:
        if concept_style in concept_tokens:
            # Use pre-trained concept embedding
            concept_embedding = concept_embeds[concept_style].unsqueeze(0).to(device)
        elif concept_style in art_concepts:
            # Generate concept embedding from text description
            concept_embedding = get_concept_embedding(
                art_concepts[concept_style], tokenizer, text_encoder, device
            )

    # Prep text
    text_input = tokenizer(
        [prompt], padding="max_length", max_length=tokenizer.model_max_length,
        truncation=True, return_tensors="pt",
    )
    with torch.inference_mode():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Apply concept embedding influence if provided
    if concept_embedding is not None and concept_strength > 0:
        # Add a batch dimension to a 2-D concept embedding so it can line up
        # with the 3-D text embedding
        if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
            concept_embedding = concept_embedding.unsqueeze(0)
        # Weighted blend between the prompt embedding and the concept; skipped
        # when the shapes still differ (same behaviour as before)
        if text_embeddings.shape == concept_embedding.shape:
            text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding

    # Unconditional embedding for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.inference_mode():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Keep the conditioning in the UNet's dtype even if the blend above
    # promoted it (e.g. a float32 concept embedding mixed with float16 text)
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=unet.dtype)

    # Prep Scheduler
    set_timesteps(scheduler, num_inference_steps)

    # Prep latents. Noise is drawn with the seeded default generator and then
    # moved to the model's device and dtype: the pipeline is loaded in half
    # precision, and feeding float32 latents to it raises a dtype mismatch.
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(device=device, dtype=unet.dtype)
    latents = latents * scheduler.init_noise_sigma

    # Loop through diffusion process
    timesteps = scheduler.timesteps
    for i, t in tqdm(enumerate(timesteps), total=len(timesteps)):
        # Expand latents for classifier-free guidance
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.inference_mode():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Apply additional guidance with vignette loss (every 5th step)
        if vignette_loss_scale > 0 and i % 5 == 0:
            # Only this branch needs sigma, so it is looked up here; a scheduler
            # without a `sigmas` attribute then still works when vignette
            # guidance is switched off.
            sigma = scheduler.sigmas[i]

            # Requires grad on the latents
            latents = latents.detach().requires_grad_()

            # Get the predicted x0
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

            # Calculate loss
            loss = vignette_loss(denoised_images) * vignette_loss_scale

            # Get gradient
            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Modify the latents based on this gradient
            latents = latents.detach() - cond_grad * sigma**2

        # Step with scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents
@spaces.GPU
def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                   vignette_loss_scale=0.0, concept_style="none", concept_strength=0.5,
                   height=512, width=512):
    """
    Render a single image for a prompt
    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_style (str): Style concept to use; "none" means no style
        concept_strength (float): Strength of concept influence (0.0-1.0)
        height (int): Image height
        width (int): Image width
    Returns:
        PIL.Image: Generated image
    """
    # The UI sends the literal string "none" when no style is selected
    selected_style = None if concept_style == "none" else concept_style

    # Run the denoising loop
    sampling_args = {
        "prompt": prompt,
        "seed": seed,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "vignette_loss_scale": vignette_loss_scale,
        "concept_style": selected_style,
        "concept_strength": concept_strength,
        "height": height,
        "width": width,
    }
    final_latents = generate_latents(**sampling_args)

    # Decode the latents and hand back the first (only) image
    decoded = latents_to_pil(final_latents, get_pipeline().vae)
    return decoded[0]
# Decorated like generate_image: this is also a Gradio event handler that runs
# the full denoising loop, so it needs the same GPU allocation.
@spaces.GPU
def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                        vignette_loss_scale=0.0, concept_strength=0.5):
    """
    Generate a grid of images with different style concepts
    Args:
        prompt (str): Text prompt
        seed (int): Base random seed; style number i uses seed + i
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_strength (float): Strength of concept influence (0.0-1.0)
    Returns:
        PIL.Image: Grid of generated images, one column per style in
        `art_concepts`, labelled with the style name
    """
    # List of styles to use
    styles = list(art_concepts)

    # Generate one image per style
    images = []
    for i, style in enumerate(styles):
        latents = generate_latents(
            prompt=prompt,
            seed=seed + i,  # Use different seeds for variety
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            vignette_loss_scale=vignette_loss_scale,
            concept_style=style,
            concept_strength=concept_strength
        )
        # Convert latents to image
        images.append(latents_to_pil(latents, get_pipeline().vae)[0])

    # Create a single-row grid; the labels are the style names themselves
    return image_grid(images, 1, len(styles), list(styles))
# Build the Gradio user interface
# NOTE(review): this function only constructs UI objects; confirm that the GPU
# decorator (and its `enable_queue` argument) is really intended here.
@spaces.GPU(enable_queue=False)
def create_demo():
    """
    Assemble the two-tab Gradio Blocks app and connect its buttons
    Returns:
        gr.Blocks: The assembled interface (not yet launched)
    """

    def sampling_sliders(seed_label):
        """Create the seed / steps / guidance / vignette sliders, in that order."""
        seed_slider = gr.Slider(minimum=0, maximum=10000, step=1, label=seed_label, value=42)
        steps_slider = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
        guidance_slider = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
        vignette_slider = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
        return seed_slider, steps_slider, guidance_slider, vignette_slider

    def strength_slider():
        """Create the concept-strength slider."""
        return gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)

    # Dropdown entries: "none", then the concept library tokens, then the
    # text-described art styles
    style_choices = ["none"] + concept_tokens + list(art_concepts.keys())

    with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
        gr.Markdown("# Guided Stable Diffusion with Styles")

        # ---- Tab 1: one image, optional style ----
        with gr.Tab("Single Image Generation"), gr.Row():
            with gr.Column():
                single_prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
                single_seed, single_steps, single_guidance, single_vignette = sampling_sliders("Seed")
                single_style = gr.Dropdown(choices=style_choices, label="Style Concept", value="none")
                single_strength = strength_slider()
                single_button = gr.Button("Generate Image")
            with gr.Column():
                single_output = gr.Image(label="Generated Image", type="pil")

        # ---- Tab 2: one image per built-in art style ----
        with gr.Tab("Style Grid"), gr.Row():
            with gr.Column():
                grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
                grid_seed, grid_steps, grid_guidance, grid_vignette = sampling_sliders("Base Seed")
                grid_strength = strength_slider()
                grid_button = gr.Button("Generate Style Grid")
            with gr.Column():
                grid_output = gr.Image(label="Style Grid", type="pil")

        # Event handlers; input order follows each handler's parameter order
        single_inputs = [
            single_prompt, single_seed, single_steps, single_guidance,
            single_vignette, single_style, single_strength,
        ]
        single_button.click(generate_image, inputs=single_inputs, outputs=single_output)

        grid_inputs = [
            grid_prompt, grid_seed, grid_steps,
            grid_guidance, grid_vignette, grid_strength,
        ]
        grid_button.click(generate_style_grid, inputs=grid_inputs, outputs=grid_output)

    return demo
# Launch the app
if __name__ == "__main__":
    demo = create_demo()
    # `cache_examples` is an option of gr.Interface / gr.Examples, not of
    # Blocks.launch(); passing it here raised TypeError before the server
    # started. This app defines no examples, so nothing is lost by dropping it.
    demo.launch(debug=False, show_error=True, server_name="0.0.0.0", server_port=7860)