Leeps's picture
Simplify prompt blending workshop UI
9084ca7
Raw
History Blame Contribute Delete
19.1 kB
import gc
from functools import lru_cache
try:
    import spaces
except ImportError:

    class _SpacesFallback:
        """Stand-in for the ``spaces`` package when it is not installed.

        Only ``GPU`` is provided, and it never alters the decorated function.
        """

        @staticmethod
        def GPU(*decorator_args, **decorator_kwargs):
            """Act as a no-op decorator, usable bare or with arguments."""

            def passthrough(func):
                return func

            # Any keyword argument, or no positional argument at all, means
            # the decorator-factory form: ``@spaces.GPU(duration=...)``.
            if decorator_kwargs or not decorator_args:
                return passthrough
            # Bare form ``@spaces.GPU``: the first positional is the function.
            candidate = decorator_args[0]
            return candidate if callable(candidate) else passthrough

    spaces = _SpacesFallback()
import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
# Title used for the Blocks page and for the top-level Markdown heading.
APP_TITLE = "Stable Diffusion Equation Playground"
# Model identifier passed to StableDiffusionPipeline.from_pretrained.
DEFAULT_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
# Initial values of the three prompt textboxes.
DEFAULT_PROMPT_A = "a cozy treehouse in a forest"
DEFAULT_PROMPT_B = "an underwater coral reef"
DEFAULT_PROMPT_C = "a colorful outer space nebula"
# Initial value of the "Things to avoid" textbox.
DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, distorted"
# Seeds are reduced modulo this value, and random seeds are drawn below it.
MAX_SEED = 2_147_483_647
# Snippet shown in the "Prompt embedding math" code panel. It is display-only
# text, but the if/else body is indented so that what students read is
# syntactically valid Python.
PROMPT_MATH_CODE = """# Diffusers normally hides this inside pipe(prompt).
# In this app, each prompt becomes a CLIP text embedding first.
embed_a, negative = encode_prompt(prompt_a, negative_prompt)
embed_b, _ = encode_prompt(prompt_b, negative_prompt)
embed_c, _ = encode_prompt(prompt_c, negative_prompt)
# The sliders choose the strength of each idea.
total = strength_a + strength_b + strength_c
if total <= 0:
    wa = wb = wc = 1 / 3
else:
    wa = strength_a / total
    wb = strength_b / total
    wc = strength_c / total
prompt_embeds = wa * embed_a + wb * embed_b + wc * embed_c
"""
# Snippet shown in the "Latent noise math" code panel (display-only text).
LATENT_MATH_CODE = """# Stable Diffusion does not start from pixels.
# It starts from noisy latents in the VAE's compressed image space.
noise_a = torch.randn(latent_shape, generator=seed_a)
noise_b = torch.randn(latent_shape, generator=seed_b)
# This hidden lever lets students mix the starting noise too.
latents = (1 - noise_mix) * noise_a + noise_mix * noise_b
latents = (latents - latents.mean()) / latents.std()
latents = latents * scheduler.init_noise_sigma
"""
# Snippet shown in the "Guidance equation" code panel (display-only text).
GUIDANCE_MATH_CODE = """# Classifier-free guidance combines two UNet predictions:
# one conditioned on the negative/unconditional prompt, one on the prompt.
noise_negative, noise_prompt = noise_pred.chunk(2)
delta = noise_prompt - noise_negative
# Standard CFG is:
# guided = noise_negative + guidance_scale * delta
guided = noise_negative + guidance_scale * delta
"""
def current_device():
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        # Look the MPS backend up defensively in case it is not present.
        mps_backend = getattr(torch.backends, "mps", None)
        has_mps = mps_backend and torch.backends.mps.is_available()
        backend = "mps" if has_mps else "cpu"
    return torch.device(backend)
def model_dtype(device):
    """Return float16 for a CUDA device and float32 for anything else."""
    return torch.float16 if device.type == "cuda" else torch.float32
def device_label(device):
    """Return a human-readable description of ``device`` for the metrics table."""
    kind = device.type
    if kind == "mps":
        return "Apple MPS GPU"
    if kind != "cuda":
        return "CPU only. This app will load, but image generation will be very slow."
    # Only query the CUDA runtime once we know the device is a CUDA device.
    gpu_name = torch.cuda.get_device_name(0)
    return f"CUDA GPU: {gpu_name}"
def round_to_multiple_of_8(value):
    """Snap ``value`` to the nearest multiple of 8, clamped to 256..768.

    The value is truncated to an int first; the nearest-multiple step uses
    Python's built-in ``round``.
    """
    whole = int(value)
    snapped = 8 * round(whole / 8)
    capped = min(768, snapped)
    return max(256, capped)
def seed_generator(seed, device):
    """Build a seeded ``torch.Generator`` for ``device``.

    The seed is reduced modulo ``MAX_SEED``. The generator lives on the given
    device only for CUDA; every other device type gets a CPU generator.
    """
    wrapped_seed = int(seed) % MAX_SEED
    generator_device = device if device.type == "cuda" else "cpu"
    return torch.Generator(device=generator_device).manual_seed(wrapped_seed)
def randn_tensor(shape, seed, device, dtype):
    """Draw seeded standard-normal noise of ``shape`` and place it on ``device``.

    On CUDA the noise is sampled directly on the device; otherwise it is
    sampled on the CPU and then moved to ``device``.
    """
    generator = seed_generator(seed, device)
    on_cuda = device.type == "cuda"
    sample = torch.randn(
        shape,
        generator=generator,
        device=device if on_cuda else "cpu",
        dtype=dtype,
    )
    return sample if on_cuda else sample.to(device)
def blank_image(message="Run generation to make an image."):
    """Return a 512x512 dark placeholder image with ``message`` drawn on it."""
    canvas = Image.new("RGB", (512, 512), (25, 29, 36))
    ImageDraw.Draw(canvas).text((32, 236), message, fill=(230, 235, 240))
    return canvas
def normalized_prompt_weights(weight_a, weight_b, weight_c):
    """Turn three raw strengths into non-negative weights that sum to 1.

    Negative strengths are clamped to zero. If nothing positive remains, each
    prompt receives an equal one-third share.
    """
    clamped = [max(0.0, float(raw)) for raw in (weight_a, weight_b, weight_c)]
    total = sum(clamped)
    if total > 0:
        return tuple(value / total for value in clamped)
    return 1 / 3, 1 / 3, 1 / 3
def compact_prompt(text, fallback):
    """Return a short single-line label for a prompt.

    Whitespace runs are collapsed to single spaces, backticks and double
    quotes become apostrophes, and anything longer than 52 characters is cut
    and suffixed with ``...``.

    Args:
        text: The prompt as entered; may be empty or ``None``.
        fallback: Label used when ``text`` has no visible content.

    Returns:
        The cleaned label, at most 55 characters long.
    """
    collapsed = " ".join(str(text or "").split())
    if not collapsed:
        # A whitespace-only prompt is truthy, so ``text or fallback`` alone
        # would yield an empty label; decide on the fallback after collapsing.
        collapsed = " ".join(str(fallback).split())
    collapsed = collapsed.replace("`", "'").replace('"', "'")
    if len(collapsed) > 52:
        return collapsed[:52] + "..."
    return collapsed
def prompt_equation(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c):
    """Describe the prompt blend as text and return the normalized weights.

    The prompt arguments are accepted but not used; only the weights affect
    the result.

    Returns:
        ``(equation, weight_a, weight_b, weight_c)`` with normalized weights.
    """
    share_a, share_b, share_c = normalized_prompt_weights(weight_a, weight_b, weight_c)
    equation = f"prompt_embedding = {share_a:.2f} * A + {share_b:.2f} * B + {share_c:.2f} * C"
    return equation, share_a, share_b, share_c
def equation_markdown(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c):
    """Render the "Current Equation" Markdown panel.

    Args:
        prompt_a, prompt_b, prompt_c: Prompt texts shown in the legend.
        weight_a, weight_b, weight_c: Raw strengths; they are normalized
            before being shown in the equation.

    Returns:
        A Markdown string with a heading, the equation in a code span, and an
        A/B/C legend with one entry per line.
    """
    equation, weight_a, weight_b, weight_c = prompt_equation(
        prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c
    )
    # Two trailing spaces before the newline make a Markdown hard line break;
    # with a single space the three legend entries flow into one line.
    return (
        "### Current Equation\n"
        f"`{equation}`\n\n"
        f"**A** = {compact_prompt(prompt_a, 'Prompt A')}  \n"
        f"**B** = {compact_prompt(prompt_b, 'Prompt B')}  \n"
        f"**C** = {compact_prompt(prompt_c, 'Prompt C')}"
    )
@lru_cache(maxsize=2)
def load_pipe(model_id, device_type):
    """Load a Stable Diffusion pipeline and configure it for ``device_type``.

    Results are cached per ``(model_id, device_type)``, so repeated calls with
    the same arguments return the same pipeline object.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        device_type: Device string such as ``"cuda"``, ``"mps"`` or ``"cpu"``.

    Returns:
        The pipeline, moved to the device, with its scheduler replaced by a
        ``DPMSolverMultistepScheduler``, the progress bar disabled and VAE
        slicing enabled.
    """
    import logging  # local import keeps this block self-contained

    device = torch.device(device_type)
    dtype = model_dtype(device)
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=True)
    pipe.enable_vae_slicing()
    if device.type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as error:
            # Best-effort optimisation: the pipeline works without it, so keep
            # going, but record the reason instead of hiding it.
            logging.getLogger(__name__).info(
                "xformers attention not enabled (%s); using default attention.",
                error,
            )
    return pipe
def encode_prompt(pipe, prompt, negative_prompt, device):
    """Encode ``prompt`` and ``negative_prompt`` with the pipeline's text encoder.

    Uses ``pipe.encode_prompt`` when the pipeline has it, otherwise falls back
    to ``pipe._encode_prompt``, whose single result is split in two with the
    negative half first.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)``.
    """
    call_kwargs = {
        "prompt": prompt,
        "device": device,
        "num_images_per_prompt": 1,
        "do_classifier_free_guidance": True,
        "negative_prompt": negative_prompt,
    }
    if not hasattr(pipe, "encode_prompt"):
        stacked = pipe._encode_prompt(**call_kwargs)
        negative_part, positive_part = stacked.chunk(2)
        return positive_part, negative_part
    positive_part, negative_part = pipe.encode_prompt(**call_kwargs)
    return positive_part, negative_part
def cosine_similarity(a, b):
    """Return the cosine similarity of two tensors as a Python float.

    Both tensors are detached, cast to float32 and flattened before comparing.
    """
    flat_a, flat_b = (tensor.detach().float().flatten() for tensor in (a, b))
    score = torch.nn.functional.cosine_similarity(flat_a, flat_b, dim=0)
    return float(score.cpu())
def mix_prompt_embeddings(
    pipe,
    device,
    prompt_a,
    prompt_b,
    prompt_c,
    negative_prompt,
    weight_a,
    weight_b,
    weight_c,
    renormalize_prompt,
):
    """Blend three prompt embeddings with normalized weights.

    Each prompt is encoded separately. The negative embedding returned is the
    one produced while encoding prompt A. When ``renormalize_prompt`` is true
    and the blend has a non-zero norm, the blend is rescaled so its norm
    matches prompt A's embedding norm.

    Returns:
        ``(mixed, negative_embeds, formula, metrics)`` where ``metrics`` is a
        list of ``[label, value]`` rows.
    """
    encoded = [
        encode_prompt(pipe, text, negative_prompt, device)
        for text in (prompt_a, prompt_b, prompt_c or "")
    ]
    (emb_a, negative_embeds), (emb_b, _), (emb_c, _) = encoded

    formula, weight_a, weight_b, weight_c = prompt_equation(
        prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c
    )
    mixed = weight_a * emb_a + weight_b * emb_b + weight_c * emb_c

    # Norms are measured in float32 regardless of the embedding dtype.
    norm_a = emb_a.detach().float().norm()
    norm_mixed = mixed.detach().float().norm()
    should_rescale = renormalize_prompt and float(norm_mixed.cpu()) > 0
    if should_rescale:
        mixed = mixed * (norm_a / norm_mixed)
        formula += "; then rescale to A's embedding norm"

    # Similarities are taken against the final (possibly rescaled) blend.
    sim_ab = cosine_similarity(emb_a, emb_b)
    sim_a = cosine_similarity(emb_a, mixed)
    sim_b = cosine_similarity(emb_b, mixed)
    sim_c = cosine_similarity(emb_c, mixed)
    final_norm = float(mixed.detach().float().norm().cpu())
    metrics = [
        ["cosine(A, B)", round(sim_ab, 4)],
        ["cosine(A, mixed)", round(sim_a, 4)],
        ["cosine(B, mixed)", round(sim_b, 4)],
        ["cosine(C, mixed)", round(sim_c, 4)],
        ["norm(A)", round(float(norm_a.cpu()), 3)],
        ["norm(mixed)", round(final_norm, 3)],
    ]
    return mixed, negative_embeds, formula, metrics
def prepare_latents(pipe, device, height, width, seed_a, seed_b, noise_mix, renormalize_noise):
    """Build the starting latents by blending two seeded noise tensors.

    The two noises are linearly interpolated by ``noise_mix``, optionally
    re-standardized (zero mean, unit standard deviation), then multiplied by
    the scheduler's ``init_noise_sigma``.

    Returns:
        ``(latents, formula, metrics)`` where ``metrics`` is a list of
        ``[label, value]`` rows.
    """
    scale = pipe.vae_scale_factor
    latent_shape = (1, int(pipe.unet.config.in_channels), height // scale, width // scale)
    dtype = model_dtype(device)

    noise_a = randn_tensor(latent_shape, seed_a, device, dtype)
    noise_b = randn_tensor(latent_shape, seed_b, device, dtype)
    blended = (1.0 - noise_mix) * noise_a + noise_mix * noise_b
    before_std = float(blended.detach().float().std().cpu())

    formula = f"noise = {(1.0 - noise_mix):.2f} * seed A + {noise_mix:.2f} * seed B"
    if renormalize_noise:
        # The small epsilon guards the division when the blend is constant.
        blended = (blended - blended.mean()) / (blended.std() + 1e-6)
        formula += "; then renormalize to unit standard deviation"
    after_std = float(blended.detach().float().std().cpu())

    init_sigma = pipe.scheduler.init_noise_sigma
    latents = blended * init_sigma
    metrics = [
        ["latent shape", str(tuple(latent_shape))],
        ["std before scheduler scale", round(before_std, 4)],
        ["std after optional renorm", round(after_std, 4)],
        ["scheduler init sigma", round(float(init_sigma), 4)],
    ]
    return latents, formula, metrics
def apply_classifier_free_guidance(noise_negative, noise_prompt, guidance_scale):
    """Combine the two noise predictions with classifier-free guidance.

    Returns:
        ``(guided, formula)`` where ``guided`` is
        ``noise_negative + guidance_scale * (noise_prompt - noise_negative)``
        and ``formula`` is a text description of that step.
    """
    direction = noise_prompt - noise_negative
    description = f"guided = negative + {guidance_scale:.2f} * (prompt - negative)"
    return noise_negative + guidance_scale * direction, description
def decode_latents(pipe, latents):
    """Decode latents through the pipeline's VAE into a list of PIL images."""
    unscaled = latents / pipe.vae.config.scaling_factor
    decoded = pipe.vae.decode(unscaled, return_dict=False)[0]
    # Shift from [-1, 1] into [0, 1] and clip anything outside that range.
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    # Move the channel axis last before converting to 8-bit values.
    frames = pixels.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    frames = (frames * 255).round().astype("uint8")
    return list(map(Image.fromarray, frames))
def checkpoint_indices(num_steps):
    """Return the sorted, de-duplicated step indices at which to snapshot.

    The indices are the first step, one third and two thirds of the way
    through, and the last step.
    """
    final_index = max(0, int(num_steps) - 1)
    marks = {0, final_index}
    marks.update((fraction * final_index) // 3 for fraction in (1, 2))
    return sorted(marks)
def gpu_duration(*args):
    """Estimate a GPU time budget in seconds from the trailing arguments.

    The last three positional arguments are read as steps, width and height.
    If they are missing or not convertible to int, 90 is returned. Otherwise
    the estimate grows with steps and image area and is clamped to 60..180.
    """
    try:
        steps, width, height = (int(value) for value in args[-3:])
    except Exception:
        return 90
    # Images smaller than 512x512 are budgeted as if they were 512x512.
    pixel_factor = max(1.0, (width * height) / (512 * 512))
    estimate = int(35 + steps * 2.5 * pixel_factor)
    return min(180, max(60, estimate))
@spaces.GPU(duration=gpu_duration)
@torch.inference_mode()
def generate(
    prompt_a,
    prompt_b,
    prompt_c,
    weight_a,
    weight_b,
    weight_c,
    seed_a,
    seed_b,
    noise_mix,
    negative_prompt,
    guidance_scale,
    num_steps,
    width,
    height,
):
    """Generate one image from a weighted blend of three prompts.

    The prompts are encoded and blended, starting noise is built from the two
    seeds, and a manual denoising loop applies classifier-free guidance at
    every scheduler step, decoding a few intermediate snapshots on the way.

    Args:
        prompt_a, prompt_b, prompt_c: Prompt texts; empty values become "".
        weight_a, weight_b, weight_c: Raw blend strengths for the prompts.
        seed_a, seed_b: Noise seeds; an empty (``None``) seed is treated as 0.
        noise_mix: Interpolation factor between the two noise tensors.
        negative_prompt: Text to steer away from; empty values become "".
        guidance_scale: Classifier-free guidance strength.
        num_steps: Number of denoising steps.
        width, height: Requested size, snapped to a multiple of 8 in 256..768.

    Returns:
        ``(final_image, snapshots, summary, metrics)``: the final PIL image, a
        list of ``(image, caption)`` snapshot pairs, a text summary, and a
        list of ``[label, value]`` rows. On a CPU-only machine a placeholder
        image, empty lists and an explanatory message are returned instead.

    Note:
        ``num_steps``, ``width`` and ``height`` must stay the last three
        parameters: ``gpu_duration`` reads them by negative index.
    """
    device = current_device()
    if device.type == "cpu":
        return (
            blank_image("GPU recommended."),
            [],
            "No GPU was detected. The app is designed for CUDA or MPS. It can run on CPU, but it may take a very long time.",
            [],
        )

    width = round_to_multiple_of_8(width)
    height = round_to_multiple_of_8(height)
    num_steps = int(num_steps)
    scheduler_name = "DPM++ 2M"

    try:
        pipe = load_pipe(DEFAULT_MODEL, device.type)
        prompt_embeds, negative_prompt_embeds, prompt_formula, prompt_metrics = mix_prompt_embeddings(
            pipe,
            device,
            prompt_a or "",
            prompt_b or "",
            prompt_c or "",
            negative_prompt or "",
            float(weight_a),
            float(weight_b),
            float(weight_c),
            True,
        )
        pipe.scheduler.set_timesteps(num_steps, device=device)
        latents, latent_formula, latent_metrics = prepare_latents(
            pipe,
            device,
            height,
            width,
            # A cleared number field arrives as None; fall back to seed 0
            # instead of failing inside int().
            int(seed_a or 0),
            int(seed_b or 0),
            float(noise_mix),
            True,
        )
        # Negative embedding first, matching the order of noise_pred.chunk(2).
        text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

        snapshots = []
        save_at = checkpoint_indices(num_steps)
        last_formula = ""
        last_step = len(pipe.scheduler.timesteps) - 1
        final_image = None
        for step_index, timestep in enumerate(pipe.scheduler.timesteps):
            # One batched UNet call covers both the negative and prompt branch.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, timestep)
            noise_pred = pipe.unet(
                latent_model_input,
                timestep,
                encoder_hidden_states=text_embeds,
                return_dict=False,
            )[0]
            noise_negative, noise_prompt = noise_pred.chunk(2)
            guided, last_formula = apply_classifier_free_guidance(
                noise_negative,
                noise_prompt,
                float(guidance_scale),
            )
            latents = pipe.scheduler.step(guided, timestep, latents, return_dict=False)[0]
            if step_index in save_at:
                snapshot = decode_latents(pipe, latents)[0]
                snapshots.append((snapshot, f"step {step_index + 1} of {num_steps}"))
                if step_index == last_step:
                    # The last snapshot is decoded from the final latents, so
                    # reuse it rather than running the VAE a second time.
                    final_image = snapshot
        if final_image is None:
            final_image = decode_latents(pipe, latents)[0]

        metrics = prompt_metrics + latent_metrics + [
            ["device", device_label(device)],
            ["prompt formula", prompt_formula],
            ["noise formula", latent_formula],
            ["guidance formula", last_formula],
        ]
        summary = (
            f"Prompt blend: {prompt_formula}\n\n"
            f"Starting noise: {latent_formula}\n\n"
            f"Guidance: {last_formula}\n\n"
            f"Steps: {num_steps}; size: {width}x{height}; scheduler: {scheduler_name}\n\n"
            "The sliders blend text embeddings, not pixels. Stable Diffusion starts from noise and uses this blended prompt "
            "to steer each denoising step."
        )
        return final_image, snapshots, summary, metrics
    finally:
        # Collect garbage first so memory held by dead tensors can actually be
        # handed back by empty_cache(); run this even when generation fails.
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()
def randomize_seeds():
    """Return two fresh random seeds, each in ``[0, MAX_SEED)``."""
    generator = np.random.default_rng()
    first_seed = generator.integers(0, MAX_SEED)
    second_seed = generator.integers(0, MAX_SEED)
    return int(first_seed), int(second_seed)
def build_app():
    """Build the Gradio Blocks interface and wire up its event handlers.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    # NOTE(review): the container nesting below follows the apparent layout
    # (controls column, results column, code cells underneath) — confirm.
    theme = gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="emerald",
        neutral_hue="slate",
        radius_size="sm",
    )
    css = """
.snapshot-gallery img { object-fit: contain !important; }
.code-panel textarea, .code-panel pre { font-size: 13px !important; }
"""
    metric_headers = ["quantity", "value"]
    with gr.Blocks(title=APP_TITLE, theme=theme, css=css) as demo:
        gr.Markdown(
            f"# {APP_TITLE}\n"
            "Blend three ideas, then watch Stable Diffusion turn noise into an image using that blended prompt embedding."
        )
        # Live preview of the blend equation; refreshed by the change handlers below.
        equation_preview = gr.Markdown(
            equation_markdown(DEFAULT_PROMPT_A, DEFAULT_PROMPT_B, DEFAULT_PROMPT_C, 1, 1, 1)
        )
        # Image size has no visible control; it is passed to generate() as fixed state.
        width = gr.State(512)
        height = gr.State(512)
        with gr.Row(equal_height=False):
            # Controls column.
            with gr.Column(scale=1, min_width=320):
                with gr.Group():
                    prompt_a = gr.Textbox(value=DEFAULT_PROMPT_A, label="Prompt A", lines=2)
                    strength_a = gr.Slider(0, 3, value=1, step=0.05, label="Strength A")
                    prompt_b = gr.Textbox(value=DEFAULT_PROMPT_B, label="Prompt B", lines=2)
                    strength_b = gr.Slider(0, 3, value=1, step=0.05, label="Strength B")
                    prompt_c = gr.Textbox(value=DEFAULT_PROMPT_C, label="Prompt C", lines=2)
                    strength_c = gr.Slider(0, 3, value=1, step=0.05, label="Strength C")
                with gr.Accordion("A Few Diffusers Levers", open=True):
                    guidance_scale = gr.Slider(1, 14, value=7.5, step=0.5, label="Prompt guidance")
                    num_steps = gr.Slider(8, 35, value=20, step=1, label="Denoising steps")
                    with gr.Row():
                        seed_a = gr.Number(value=11, precision=0, label="Starting noise seed")
                        random_seeds = gr.Button("Random seed")
                    negative_prompt = gr.Textbox(
                        value=DEFAULT_NEGATIVE_PROMPT,
                        label="Things to avoid",
                        lines=1,
                    )
                with gr.Accordion("Extra Noise Mixer", open=False):
                    with gr.Row():
                        seed_b = gr.Number(value=2222, precision=0, label="Second noise seed")
                        noise_mix = gr.Slider(0, 1, value=0.0, step=0.05, label="Second seed strength")
                generate_button = gr.Button("Generate", variant="primary")
            # Results column.
            with gr.Column(scale=1, min_width=420):
                output_image = gr.Image(
                    value=blank_image(),
                    label="Generated image",
                    type="pil",
                    interactive=False,
                )
                summary = gr.Textbox(label="What happened", lines=12, interactive=False)
                with gr.Accordion("Denoising Snapshots", open=False):
                    snapshots = gr.Gallery(
                        label="Decoded latent snapshots",
                        columns=2,
                        height=420,
                        object_fit="contain",
                        elem_classes=["snapshot-gallery"],
                    )
                with gr.Accordion("Embedding Measurements", open=False):
                    metrics = gr.Dataframe(
                        headers=metric_headers,
                        datatype=["str", "str"],
                        label="Embedding and latent measurements",
                        interactive=False,
                    )
        # Read-only code panels showing the three display snippets.
        with gr.Accordion("Code Cells", open=False):
            with gr.Row(equal_height=False):
                gr.Code(PROMPT_MATH_CODE, language="python", label="Prompt embedding math", interactive=False, elem_classes=["code-panel"])
                gr.Code(LATENT_MATH_CODE, language="python", label="Latent noise math", interactive=False, elem_classes=["code-panel"])
                gr.Code(GUIDANCE_MATH_CODE, language="python", label="Guidance equation", interactive=False, elem_classes=["code-panel"])
        # The single "Random seed" button refreshes both seed fields.
        random_seeds.click(
            randomize_seeds,
            inputs=None,
            outputs=[seed_a, seed_b],
            show_progress="hidden",
        )
        # Any edit to a prompt or a strength re-renders the equation preview.
        for equation_input in [prompt_a, prompt_b, prompt_c, strength_a, strength_b, strength_c]:
            equation_input.change(
                equation_markdown,
                inputs=[prompt_a, prompt_b, prompt_c, strength_a, strength_b, strength_c],
                outputs=[equation_preview],
                show_progress="hidden",
            )
        # Input order must match generate()'s parameter order; gpu_duration
        # also reads the last three (num_steps, width, height) by position.
        generate_button.click(
            generate,
            inputs=[
                prompt_a,
                prompt_b,
                prompt_c,
                strength_a,
                strength_b,
                strength_c,
                seed_a,
                seed_b,
                noise_mix,
                negative_prompt,
                guidance_scale,
                num_steps,
                width,
                height,
            ],
            outputs=[output_image, snapshots, summary, metrics],
            show_progress="full",
        )
    return demo
if __name__ == "__main__":
    # Build the UI, enable a request queue capped at 8 entries, then serve.
    app = build_app()
    app.queue(max_size=8).launch()