Spaces:
Paused
Paused
| import gc | |
| from functools import lru_cache | |
try:
    import spaces
except ImportError:
    # The `spaces` package could not be imported: substitute a stand-in whose
    # GPU decorator does nothing, so decorated functions still work unchanged.
    class _SpacesFallback:
        """No-op replacement exposing a ``GPU`` decorator."""

        @staticmethod
        def GPU(*decorator_args, **decorator_kwargs):
            """Return the decorated function unchanged.

            Supports both decorator spellings:

            * bare: ``@spaces.GPU`` -- called with the function itself;
            * configured: ``@spaces.GPU(duration=...)`` -- called with options
              and expected to return the real decorator.

            Declared as a static method because it is accessed through an
            instance (``spaces = _SpacesFallback()``). As a plain method the
            instance would arrive as ``decorator_args[0]``, the bare form
            would never be recognised, and the decorated function would be
            replaced by the inner ``decorator`` helper.
            """
            if decorator_args and callable(decorator_args[0]) and not decorator_kwargs:
                return decorator_args[0]

            def decorator(func):
                return func

            return decorator

    spaces = _SpacesFallback()
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| from PIL import Image, ImageDraw | |
| from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline | |
# Heading text and browser-tab title of the app.
APP_TITLE = "Stable Diffusion Equation Playground"

# Model identifier passed to StableDiffusionPipeline.from_pretrained.
DEFAULT_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Exclusive upper bound for random seeds; user seeds are wrapped modulo this.
MAX_SEED = 2**31 - 1

# Initial contents of the three prompt boxes and the negative-prompt box.
DEFAULT_PROMPT_A = "a cozy treehouse in a forest"
DEFAULT_PROMPT_B = "an underwater coral reef"
DEFAULT_PROMPT_C = "a colorful outer space nebula"
DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, distorted"
# The three strings below are Python snippets kept as text. Within this file
# they are only handed to read-only gr.Code panels; nothing executes them.
# They are simplified illustrations (for example, encode_prompt is shown with
# two arguments) rather than copies of the functions defined in this module.
# NOTE(review): indentation inside the snippets was reconstructed as four
# spaces -- confirm against the deployed file.

# Snippet: how the three prompt embeddings are weighted and summed.
PROMPT_MATH_CODE = """# Diffusers normally hides this inside pipe(prompt).
# In this app, each prompt becomes a CLIP text embedding first.
embed_a, negative = encode_prompt(prompt_a, negative_prompt)
embed_b, _ = encode_prompt(prompt_b, negative_prompt)
embed_c, _ = encode_prompt(prompt_c, negative_prompt)
# The sliders choose the strength of each idea.
total = strength_a + strength_b + strength_c
if total <= 0:
    wa = wb = wc = 1 / 3
else:
    wa = strength_a / total
    wb = strength_b / total
    wc = strength_c / total
prompt_embeds = wa * embed_a + wb * embed_b + wc * embed_c
"""

# Snippet: how two seeded noise tensors are blended into the starting latents.
LATENT_MATH_CODE = """# Stable Diffusion does not start from pixels.
# It starts from noisy latents in the VAE's compressed image space.
noise_a = torch.randn(latent_shape, generator=seed_a)
noise_b = torch.randn(latent_shape, generator=seed_b)
# This hidden lever lets students mix the starting noise too.
latents = (1 - noise_mix) * noise_a + noise_mix * noise_b
latents = (latents - latents.mean()) / latents.std()
latents = latents * scheduler.init_noise_sigma
"""

# Snippet: the classifier-free guidance combination applied at every step.
GUIDANCE_MATH_CODE = """# Classifier-free guidance combines two UNet predictions:
# one conditioned on the negative/unconditional prompt, one on the prompt.
noise_negative, noise_prompt = noise_pred.chunk(2)
delta = noise_prompt - noise_negative
# Standard CFG is:
# guided = noise_negative + guidance_scale * delta
guided = noise_negative + guidance_scale * delta
"""
def current_device():
    """Pick the torch device to run on: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        # The getattr guard covers torch builds that have no `mps` backend.
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def model_dtype(device):
    """Return the tensor dtype used on *device*: float16 on CUDA, else float32."""
    return torch.float16 if device.type == "cuda" else torch.float32
def device_label(device):
    """Return a human-readable description of *device* for the metrics table."""
    kind = device.type
    if kind == "mps":
        return "Apple MPS GPU"
    if kind != "cuda":
        return "CPU only. This app will load, but image generation will be very slow."
    # Always reports CUDA device index 0, whatever index *device* carries.
    return f"CUDA GPU: {torch.cuda.get_device_name(0)}"
def round_to_multiple_of_8(value):
    """Snap *value* to the nearest multiple of 8, limited to 256..768.

    *value* is truncated to an int first. Exact halves follow Python's
    ``round`` (ties go to the even quotient), so 260 -> 256 but 268 -> 272.
    """
    snapped = round(int(value) / 8) * 8
    if snapped < 256:
        return 256
    if snapped > 768:
        return 768
    return snapped
def seed_generator(seed, device):
    """Create a ``torch.Generator`` seeded with *seed* wrapped modulo MAX_SEED.

    The generator lives on *device* only when that is CUDA; for every other
    device type a CPU generator is returned.
    """
    wrapped_seed = int(seed) % MAX_SEED
    generator_device = device if device.type == "cuda" else "cpu"
    return torch.Generator(device=generator_device).manual_seed(wrapped_seed)
def randn_tensor(shape, seed, device, dtype):
    """Draw seeded standard-normal noise of *shape* and *dtype* on *device*.

    On CUDA the noise is sampled directly on the device. On any other device
    it is sampled on the CPU with a CPU generator and then moved over.
    """
    generator = seed_generator(seed, device)
    on_cuda = device.type == "cuda"
    sample = torch.randn(
        shape,
        generator=generator,
        device=device if on_cuda else "cpu",
        dtype=dtype,
    )
    return sample if on_cuda else sample.to(device)
def blank_image(message="Run generation to make an image."):
    """Return a 512x512 dark placeholder image with *message* drawn on it."""
    canvas = Image.new("RGB", (512, 512), (25, 29, 36))
    ImageDraw.Draw(canvas).text((32, 236), message, fill=(230, 235, 240))
    return canvas
def normalized_prompt_weights(weight_a, weight_b, weight_c):
    """Turn three raw strengths into non-negative weights that sum to one.

    Negative inputs are clamped to zero. If nothing positive remains, the
    three weights are split evenly (one third each).
    """
    clamped = tuple(max(0.0, float(raw)) for raw in (weight_a, weight_b, weight_c))
    total = sum(clamped)
    if total > 0:
        return tuple(value / total for value in clamped)
    return (1 / 3,) * 3
def compact_prompt(text, fallback):
    """Return a single-line, Markdown-safe preview of a prompt.

    Falls back to *fallback* when *text* is empty or None, collapses all
    whitespace runs to single spaces, swaps backticks and double quotes for
    apostrophes, and cuts the result to 52 characters plus "..." when longer.
    """
    source = str(text or fallback)
    flattened = " ".join(source.split())
    safe = flattened.translate(str.maketrans({"`": "'", '"': "'"}))
    if len(safe) > 52:
        return safe[:52] + "..."
    return safe
def prompt_equation(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c):
    """Describe the prompt blend as text and return the normalized weights.

    Returns ``(equation, weight_a, weight_b, weight_c)`` where the weights
    come from ``normalized_prompt_weights``. The prompt arguments are
    accepted but not used in the equation text.
    """
    share_a, share_b, share_c = normalized_prompt_weights(weight_a, weight_b, weight_c)
    equation = f"prompt_embedding = {share_a:.2f} * A + {share_b:.2f} * B + {share_c:.2f} * C"
    return equation, share_a, share_b, share_c
def equation_markdown(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c):
    """Build the Markdown for the "Current Equation" preview.

    The output is a heading, the blend equation as inline code, and a
    three-line legend mapping A/B/C to shortened versions of the prompts.

    Args:
        prompt_a, prompt_b, prompt_c: Prompt texts; empty or None values are
            shown as "Prompt A" / "Prompt B" / "Prompt C".
        weight_a, weight_b, weight_c: Raw slider strengths; they are
            normalized before being printed in the equation.

    Returns:
        The Markdown source as a single string.
    """
    equation, *_ = prompt_equation(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c)
    # Two trailing spaces before the newline make a Markdown hard line break.
    # With a single space the three legend entries render as one paragraph.
    legend = "  \n".join(
        f"**{letter}** = {compact_prompt(text, f'Prompt {letter}')}"
        for letter, text in (("A", prompt_a), ("B", prompt_b), ("C", prompt_c))
    )
    return f"### Current Equation\n`{equation}`\n\n{legend}"
@lru_cache(maxsize=1)
def load_pipe(model_id, device_type):
    """Load and configure the Stable Diffusion pipeline, reusing it across calls.

    The result is memoized on ``(model_id, device_type)`` so repeated
    generations do not rebuild the pipeline and move it to the device each
    time. Only the most recent pipeline is kept, so at most one stays
    resident.

    Args:
        model_id: Identifier passed to ``StableDiffusionPipeline.from_pretrained``.
        device_type: Device type string such as ``"cuda"``, ``"mps"`` or
            ``"cpu"``. It must be hashable, which is why the device is passed
            as a string.

    Returns:
        The pipeline on the requested device, with a DPM-Solver multistep
        scheduler, progress bars disabled and VAE slicing enabled.

    Note:
        The same pipeline object is handed to every caller, so anything a
        caller changes on it (for example scheduler timesteps) persists until
        it is set again.
    """
    device = torch.device(device_type)
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=model_dtype(device),
        use_safetensors=True,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=True)
    pipe.enable_vae_slicing()
    if device.type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception:
            # Deliberately best-effort: if enabling xformers attention fails
            # for any reason, keep the default attention implementation.
            pass
    return pipe
def encode_prompt(pipe, prompt, negative_prompt, device):
    """Encode a prompt and its negative prompt with the pipeline's text encoder.

    Uses ``pipe.encode_prompt`` when the pipeline has it. Otherwise falls
    back to ``pipe._encode_prompt``, whose single result is split in two
    halves: negative embeddings first, prompt embeddings second.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)``.
    """
    encode_kwargs = dict(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
    )
    if not hasattr(pipe, "encode_prompt"):
        stacked = pipe._encode_prompt(**encode_kwargs)
        negative_prompt_embeds, prompt_embeds = stacked.chunk(2)
        return prompt_embeds, negative_prompt_embeds

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(**encode_kwargs)
    return prompt_embeds, negative_prompt_embeds
def cosine_similarity(a, b):
    """Return the cosine similarity of two tensors as a Python float.

    Both tensors are detached, cast to float32 and flattened to one
    dimension first, so each is compared as a single vector.
    """
    flat_a, flat_b = (tensor.detach().float().flatten() for tensor in (a, b))
    score = torch.nn.functional.cosine_similarity(flat_a, flat_b, dim=0)
    return score.item()
def mix_prompt_embeddings(
    pipe,
    device,
    prompt_a,
    prompt_b,
    prompt_c,
    negative_prompt,
    weight_a,
    weight_b,
    weight_c,
    renormalize_prompt,
):
    """Blend three prompt embeddings into one using normalized weights.

    Each prompt is encoded separately, then the embeddings are combined as a
    weighted sum. When *renormalize_prompt* is true and the blend has a
    non-zero norm, the blend is rescaled so its overall norm equals that of
    prompt A's embedding.

    Returns:
        ``(mixed, negative_embeds, formula, metrics)``. The negative
        embeddings come from the prompt A encoding call, *formula* is a text
        description of the blend, and *metrics* is a list of
        ``[label, value]`` rows with cosine similarities and norms.
    """
    emb_a, negative_embeds = encode_prompt(pipe, prompt_a, negative_prompt, device)
    emb_b = encode_prompt(pipe, prompt_b, negative_prompt, device)[0]
    emb_c = encode_prompt(pipe, prompt_c or "", negative_prompt, device)[0]

    formula, share_a, share_b, share_c = prompt_equation(
        prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c
    )
    mixed = share_a * emb_a + share_b * emb_b + share_c * emb_c

    def whole_norm(tensor):
        # Norm over the entire tensor, computed in float32.
        return tensor.detach().float().norm()

    norm_a = whole_norm(emb_a)
    norm_mixed = whole_norm(mixed)
    if renormalize_prompt and float(norm_mixed.cpu()) > 0:
        mixed = mixed * (norm_a / norm_mixed)
        formula += "; then rescale to A's embedding norm"

    # Similarities use the final blend, i.e. after any rescaling above.
    comparisons = (
        ("cosine(A, B)", emb_a, emb_b),
        ("cosine(A, mixed)", emb_a, mixed),
        ("cosine(B, mixed)", emb_b, mixed),
        ("cosine(C, mixed)", emb_c, mixed),
    )
    metrics = [
        [label, round(cosine_similarity(left, right), 4)]
        for label, left, right in comparisons
    ]
    metrics.append(["norm(A)", round(float(norm_a.cpu()), 3)])
    metrics.append(["norm(mixed)", round(float(whole_norm(mixed).cpu()), 3)])
    return mixed, negative_embeds, formula, metrics
def prepare_latents(pipe, device, height, width, seed_a, seed_b, noise_mix, renormalize_noise):
    """Build the starting latents by blending two seeded noise tensors.

    The blend is ``(1 - noise_mix) * noise_a + noise_mix * noise_b``. When
    *renormalize_noise* is true the blend is shifted to zero mean and scaled
    to unit standard deviation. The result is finally multiplied by the
    scheduler's ``init_noise_sigma``.

    Returns:
        ``(latents, formula, metrics)`` where *formula* describes the blend
        in text and *metrics* is a list of ``[label, value]`` rows.
    """
    downscale = pipe.vae_scale_factor
    latent_shape = (
        1,
        int(pipe.unet.config.in_channels),
        height // downscale,
        width // downscale,
    )
    dtype = model_dtype(device)
    keep = 1.0 - noise_mix

    def measured_std(tensor):
        # Standard deviation over the whole tensor, computed in float32.
        return float(tensor.detach().float().std().cpu())

    noise_a = randn_tensor(latent_shape, seed_a, device, dtype)
    noise_b = randn_tensor(latent_shape, seed_b, device, dtype)
    latents = keep * noise_a + noise_mix * noise_b
    before_std = measured_std(latents)

    formula = f"noise = {keep:.2f} * seed A + {noise_mix:.2f} * seed B"
    if renormalize_noise:
        # The small epsilon guards against dividing by a zero deviation.
        latents = (latents - latents.mean()) / (latents.std() + 1e-6)
        formula += "; then renormalize to unit standard deviation"
    after_std = measured_std(latents)

    init_sigma = pipe.scheduler.init_noise_sigma
    latents = latents * init_sigma

    metrics = [
        ["latent shape", str(tuple(latent_shape))],
        ["std before scheduler scale", round(before_std, 4)],
        ["std after optional renorm", round(after_std, 4)],
        ["scheduler init sigma", round(float(init_sigma), 4)],
    ]
    return latents, formula, metrics
def apply_classifier_free_guidance(noise_negative, noise_prompt, guidance_scale):
    """Combine the two noise predictions with classifier-free guidance.

    Returns:
        ``(guided, formula)`` where
        ``guided = noise_negative + guidance_scale * (noise_prompt - noise_negative)``
        and *formula* is the same expression as display text.
    """
    guided = noise_negative + guidance_scale * (noise_prompt - noise_negative)
    return guided, f"guided = negative + {guidance_scale:.2f} * (prompt - negative)"
def decode_latents(pipe, latents):
    """Decode latents with the pipeline's VAE into a list of PIL images.

    The latents are divided by the VAE scaling factor and decoded. The
    decoder output is mapped through ``x / 2 + 0.5``, clamped to 0..1, moved
    to the CPU, reordered so channels come last, and converted to uint8.
    """
    unscaled = latents / pipe.vae.config.scaling_factor
    decoded = pipe.vae.decode(unscaled, return_dict=False)[0]
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    frames = pixels.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    frames = (frames * 255).round().astype("uint8")
    return list(map(Image.fromarray, frames))
def checkpoint_indices(num_steps):
    """Return sorted, de-duplicated step indices at which to save a snapshot.

    The indices are the first step, the last step, and the points one third
    and two thirds of the way to the last step.
    """
    last = max(0, int(num_steps) - 1)
    return sorted({(part * last) // 3 for part in range(4)})
def gpu_duration(*args):
    """Estimate a duration from the last three arguments: steps, width, height.

    The estimate is ``35 + 2.5 * steps * pixel_factor``, where
    ``pixel_factor`` is the image area relative to 512x512 and never below
    1. The result is truncated to an int and kept within 60..180. If the
    last three arguments are missing or not convertible to int, 90 is
    returned.
    """
    try:
        steps, width, height = (int(value) for value in args[-3:])
    except Exception:
        return 90
    pixel_factor = max(1.0, (width * height) / (512 * 512))
    estimate = int(35 + steps * 2.5 * pixel_factor)
    return min(180, max(60, estimate))
@spaces.GPU(duration=gpu_duration)
def generate(
    prompt_a,
    prompt_b,
    prompt_c,
    weight_a,
    weight_b,
    weight_c,
    seed_a,
    seed_b,
    noise_mix,
    negative_prompt,
    guidance_scale,
    num_steps,
    width,
    height,
):
    """Generate one image from a weighted blend of three prompts.

    Steps: blend the prompt embeddings, build the starting latents from two
    seeded noise tensors, run the denoising loop with classifier-free
    guidance, and decode the result. A few intermediate latents are decoded
    along the way as snapshots.

    The function is wrapped in ``spaces.GPU`` with ``gpu_duration`` as the
    duration callback; that callback reads the last three arguments, so
    ``num_steps, width, height`` must stay last in the signature.

    Args:
        prompt_a, prompt_b, prompt_c: Prompt texts (None is treated as "").
        weight_a, weight_b, weight_c: Raw strengths for the three prompts.
        seed_a, seed_b: Seeds for the two starting-noise tensors.
        noise_mix: Share of the second noise tensor in the starting noise.
        negative_prompt: Negative prompt text (None is treated as "").
        guidance_scale: Classifier-free guidance scale.
        num_steps: Number of denoising steps.
        width, height: Requested size; each is snapped to a multiple of 8
            within the supported range.

    Returns:
        ``(final_image, snapshots, summary, metrics)``. *snapshots* is a list
        of ``(image, caption)`` pairs and *metrics* is a list of
        ``[label, value]`` rows. When no GPU is available a placeholder
        image, an explanatory message and empty lists are returned instead.
    """
    device = current_device()
    if device.type == "cpu":
        return (
            blank_image("GPU recommended."),
            [],
            "No GPU was detected. The app is designed for CUDA or MPS. It can run on CPU, but it may take a very long time.",
            [],
        )

    width = round_to_multiple_of_8(width)
    height = round_to_multiple_of_8(height)
    num_steps = int(num_steps)
    guidance_scale = float(guidance_scale)  # converted once, not on every step
    scheduler_name = "DPM++ 2M"
    pipe = load_pipe(DEFAULT_MODEL, device.type)

    try:
        # Inference only: without no_grad every UNet/VAE call would record
        # autograd history, and because each step feeds the previous latents
        # back in, that history would keep growing across the whole loop.
        with torch.no_grad():
            prompt_embeds, negative_prompt_embeds, prompt_formula, prompt_metrics = mix_prompt_embeddings(
                pipe,
                device,
                prompt_a or "",
                prompt_b or "",
                prompt_c or "",
                negative_prompt or "",
                float(weight_a),
                float(weight_b),
                float(weight_c),
                True,
            )
            pipe.scheduler.set_timesteps(num_steps, device=device)
            latents, latent_formula, latent_metrics = prepare_latents(
                pipe,
                device,
                height,
                width,
                int(seed_a),
                int(seed_b),
                float(noise_mix),
                True,
            )
            # Negative first, prompt second: must match the chunk(2) order below.
            text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            snapshots = []
            save_at = checkpoint_indices(num_steps)
            last_formula = ""
            for step_index, timestep in enumerate(pipe.scheduler.timesteps):
                # One batched UNet call covers both guidance branches.
                latent_model_input = torch.cat([latents] * 2)
                latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, timestep)
                noise_pred = pipe.unet(
                    latent_model_input,
                    timestep,
                    encoder_hidden_states=text_embeds,
                    return_dict=False,
                )[0]
                noise_negative, noise_prompt = noise_pred.chunk(2)
                guided, last_formula = apply_classifier_free_guidance(
                    noise_negative,
                    noise_prompt,
                    guidance_scale,
                )
                latents = pipe.scheduler.step(guided, timestep, latents, return_dict=False)[0]
                if step_index in save_at:
                    snapshot = decode_latents(pipe, latents)[0]
                    snapshots.append((snapshot, f"step {step_index + 1} of {num_steps}"))
            final_image = decode_latents(pipe, latents)[0]
    finally:
        # Release cached GPU memory even when generation fails part-way.
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    metrics = prompt_metrics + latent_metrics + [
        ["device", device_label(device)],
        ["prompt formula", prompt_formula],
        ["noise formula", latent_formula],
        ["guidance formula", last_formula],
    ]
    summary = (
        f"Prompt blend: {prompt_formula}\n\n"
        f"Starting noise: {latent_formula}\n\n"
        f"Guidance: {last_formula}\n\n"
        f"Steps: {num_steps}; size: {width}x{height}; scheduler: {scheduler_name}\n\n"
        "The sliders blend text embeddings, not pixels. Stable Diffusion starts from noise and uses this blended prompt "
        "to steer each denoising step."
    )
    return final_image, snapshots, summary, metrics
def randomize_seeds():
    """Return two fresh random seeds, each an int in ``[0, MAX_SEED)``."""
    rng = np.random.default_rng()
    return tuple(int(rng.integers(0, MAX_SEED)) for _ in range(2))
def build_app():
    """Build the Gradio Blocks interface and wire its events.

    Creates the prompt/strength inputs, sampler controls, output widgets and
    read-only code panels, then connects three behaviours: the random-seed
    button, the live equation preview, and the Generate button.

    Returns:
        The ``gr.Blocks`` instance, not yet launched.
    """
    # NOTE(review): the nesting of the layout containers below was inferred
    # from the order of the statements -- confirm against the deployed app.
    theme = gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="emerald",
        neutral_hue="slate",
        radius_size="sm",
    )
    # Styles for the elem_classes used below: the snapshot gallery and the
    # code panels.
    css = """
    .snapshot-gallery img { object-fit: contain !important; }
    .code-panel textarea, .code-panel pre { font-size: 13px !important; }
    """
    metric_headers = ["quantity", "value"]
    with gr.Blocks(title=APP_TITLE, theme=theme, css=css) as demo:
        gr.Markdown(
            f"# {APP_TITLE}\n"
            "Blend three ideas, then watch Stable Diffusion turn noise into an image using that blended prompt embedding."
        )
        # Initial preview uses the default prompts with equal strengths.
        equation_preview = gr.Markdown(
            equation_markdown(DEFAULT_PROMPT_A, DEFAULT_PROMPT_B, DEFAULT_PROMPT_C, 1, 1, 1)
        )
        # Image size has no visible control; it is carried as fixed state and
        # passed to generate() as its last two inputs.
        width = gr.State(512)
        height = gr.State(512)
        with gr.Row(equal_height=False):
            # Left column: all inputs.
            with gr.Column(scale=1, min_width=320):
                with gr.Group():
                    prompt_a = gr.Textbox(value=DEFAULT_PROMPT_A, label="Prompt A", lines=2)
                    strength_a = gr.Slider(0, 3, value=1, step=0.05, label="Strength A")
                    prompt_b = gr.Textbox(value=DEFAULT_PROMPT_B, label="Prompt B", lines=2)
                    strength_b = gr.Slider(0, 3, value=1, step=0.05, label="Strength B")
                    prompt_c = gr.Textbox(value=DEFAULT_PROMPT_C, label="Prompt C", lines=2)
                    strength_c = gr.Slider(0, 3, value=1, step=0.05, label="Strength C")
                with gr.Accordion("A Few Diffusers Levers", open=True):
                    guidance_scale = gr.Slider(1, 14, value=7.5, step=0.5, label="Prompt guidance")
                    num_steps = gr.Slider(8, 35, value=20, step=1, label="Denoising steps")
                    with gr.Row():
                        seed_a = gr.Number(value=11, precision=0, label="Starting noise seed")
                        random_seeds = gr.Button("Random seed")
                    negative_prompt = gr.Textbox(
                        value=DEFAULT_NEGATIVE_PROMPT,
                        label="Things to avoid",
                        lines=1,
                    )
                with gr.Accordion("Extra Noise Mixer", open=False):
                    with gr.Row():
                        seed_b = gr.Number(value=2222, precision=0, label="Second noise seed")
                        # Default 0.0 means only the first seed's noise is used.
                        noise_mix = gr.Slider(0, 1, value=0.0, step=0.05, label="Second seed strength")
                generate_button = gr.Button("Generate", variant="primary")
            # Right column: all outputs.
            with gr.Column(scale=1, min_width=420):
                output_image = gr.Image(
                    value=blank_image(),
                    label="Generated image",
                    type="pil",
                    interactive=False,
                )
                summary = gr.Textbox(label="What happened", lines=12, interactive=False)
                with gr.Accordion("Denoising Snapshots", open=False):
                    snapshots = gr.Gallery(
                        label="Decoded latent snapshots",
                        columns=2,
                        height=420,
                        object_fit="contain",
                        elem_classes=["snapshot-gallery"],
                    )
                with gr.Accordion("Embedding Measurements", open=False):
                    metrics = gr.Dataframe(
                        headers=metric_headers,
                        datatype=["str", "str"],
                        label="Embedding and latent measurements",
                        interactive=False,
                    )
        # Read-only panels showing the explanatory code snippets.
        with gr.Accordion("Code Cells", open=False):
            with gr.Row(equal_height=False):
                gr.Code(PROMPT_MATH_CODE, language="python", label="Prompt embedding math", interactive=False, elem_classes=["code-panel"])
                gr.Code(LATENT_MATH_CODE, language="python", label="Latent noise math", interactive=False, elem_classes=["code-panel"])
                gr.Code(GUIDANCE_MATH_CODE, language="python", label="Guidance equation", interactive=False, elem_classes=["code-panel"])
        # The single "Random seed" button refreshes both seeds at once.
        random_seeds.click(
            randomize_seeds,
            inputs=None,
            outputs=[seed_a, seed_b],
            show_progress="hidden",
        )
        # Any change to a prompt or strength re-renders the equation preview.
        for equation_input in [prompt_a, prompt_b, prompt_c, strength_a, strength_b, strength_c]:
            equation_input.change(
                equation_markdown,
                inputs=[prompt_a, prompt_b, prompt_c, strength_a, strength_b, strength_c],
                outputs=[equation_preview],
                show_progress="hidden",
            )
        # The input order must match generate()'s positional parameters.
        generate_button.click(
            generate,
            inputs=[
                prompt_a,
                prompt_b,
                prompt_c,
                strength_a,
                strength_b,
                strength_c,
                seed_a,
                seed_b,
                noise_mix,
                negative_prompt,
                guidance_scale,
                num_steps,
                width,
                height,
            ],
            outputs=[output_image, snapshots, summary, metrics],
            show_progress="full",
        )
    return demo
# Script entry point: build the interface, enable Gradio's request queue with
# max_size=8, and start the server.
if __name__ == "__main__":
    build_app().queue(max_size=8).launch()