import gc
from functools import lru_cache

# `spaces` is imported before torch on purpose of keeping the original import
# order; when the package is missing a stand-in with the same decorator API is
# installed so the rest of the module can use `@spaces.GPU` unconditionally.
try:
    import spaces
except ImportError:

    class _SpacesFallback:
        """Stand-in for the `spaces` package whose GPU decorator does nothing."""

        @staticmethod
        def GPU(*decorator_args, **decorator_kwargs):
            """Return the decorated function unchanged.

            Works both as a bare decorator (`@spaces.GPU`) and as a decorator
            factory (`@spaces.GPU(duration=...)`).
            """
            used_bare = (
                bool(decorator_args)
                and not decorator_kwargs
                and callable(decorator_args[0])
            )
            if used_bare:
                return decorator_args[0]

            def passthrough(func):
                return func

            return passthrough

    spaces = _SpacesFallback()

import gradio as gr
import numpy as np
import torch
from PIL import Image, ImageDraw
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline


APP_TITLE = "Stable Diffusion Equation Playground"
DEFAULT_MODEL = "stable-diffusion-v1-5/stable-diffusion-v1-5"
DEFAULT_PROMPT_A = "a cozy treehouse in a forest"
DEFAULT_PROMPT_B = "an underwater coral reef"
DEFAULT_PROMPT_C = "a colorful outer space nebula"
DEFAULT_NEGATIVE_PROMPT = "blurry, low quality, distorted"
MAX_SEED = 2_147_483_647

# The three strings below are display-only snippets shown in the UI; they are
# never executed.
PROMPT_MATH_CODE = """# Diffusers normally hides this inside pipe(prompt).
# In this app, each prompt becomes a CLIP text embedding first.
embed_a, negative = encode_prompt(prompt_a, negative_prompt)
embed_b, _ = encode_prompt(prompt_b, negative_prompt)
embed_c, _ = encode_prompt(prompt_c, negative_prompt)

# The sliders choose the strength of each idea.
total = strength_a + strength_b + strength_c
if total <= 0:
    wa = wb = wc = 1 / 3
else:
    wa = strength_a / total
    wb = strength_b / total
    wc = strength_c / total

prompt_embeds = wa * embed_a + wb * embed_b + wc * embed_c
"""

LATENT_MATH_CODE = """# Stable Diffusion does not start from pixels.
# It starts from noisy latents in the VAE's compressed image space.
noise_a = torch.randn(latent_shape, generator=seed_a)
noise_b = torch.randn(latent_shape, generator=seed_b)

# This hidden lever lets students mix the starting noise too.
latents = (1 - noise_mix) * noise_a + noise_mix * noise_b
latents = (latents - latents.mean()) / latents.std()
latents = latents * scheduler.init_noise_sigma
"""

GUIDANCE_MATH_CODE = """# Classifier-free guidance combines two UNet predictions:
# one conditioned on the negative/unconditional prompt, one on the prompt.
noise_negative, noise_prompt = noise_pred.chunk(2)
delta = noise_prompt - noise_negative

# Standard CFG is:
# guided = noise_negative + guidance_scale * delta
guided = noise_negative + guidance_scale * delta
"""


def current_device():
    """Pick the torch device to run on: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        backend = "cuda"
    else:
        # Older torch builds have no `torch.backends.mps` attribute at all.
        mps_backend = getattr(torch.backends, "mps", None)
        backend = "mps" if mps_backend and mps_backend.is_available() else "cpu"
    return torch.device(backend)


def model_dtype(device):
    """Return float16 for a CUDA device and float32 for every other device."""
    return torch.float16 if device.type == "cuda" else torch.float32


def device_label(device):
    """Return a human-readable description of `device` for display."""
    kind = device.type
    if kind == "mps":
        return "Apple MPS GPU"
    if kind != "cuda":
        return "CPU only. This app will load, but image generation will be very slow."
    return f"CUDA GPU: {torch.cuda.get_device_name(0)}"
# Default seeds shown in the UI; also used when a seed field arrives empty.
DEFAULT_SEED_A = 11
DEFAULT_SEED_B = 2222


def round_to_multiple_of_8(value):
    """Round `value` to the nearest multiple of 8, clamped to [256, 768]."""
    value = int(value)
    return max(256, min(768, 8 * round(value / 8)))


def _coerce_seed(seed, fallback):
    """Return `seed` as an int, or `fallback` when it cannot be converted.

    A cleared numeric input reaches the callback as None, which `int()`
    rejects; falling back keeps generation working instead of raising.
    """
    try:
        return int(seed)
    except (TypeError, ValueError, OverflowError):
        return int(fallback)


def seed_generator(seed, device):
    """Build a seeded `torch.Generator`.

    The generator lives on the CUDA device when `device` is CUDA and on the
    CPU for every other device type. The seed is reduced modulo `MAX_SEED`.
    """
    seed = int(seed) % MAX_SEED
    if device.type == "cuda":
        return torch.Generator(device=device).manual_seed(seed)
    return torch.Generator(device="cpu").manual_seed(seed)


def randn_tensor(shape, seed, device, dtype):
    """Draw seeded standard-normal noise of `shape` and place it on `device`.

    For non-CUDA devices the noise is sampled on the CPU and then moved,
    matching the CPU generator returned by `seed_generator`.
    """
    generator = seed_generator(seed, device)
    if device.type == "cuda":
        return torch.randn(shape, generator=generator, device=device, dtype=dtype)
    return torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)


def blank_image(message="Run generation to make an image."):
    """Return a 512x512 dark placeholder image with `message` drawn on it."""
    image = Image.new("RGB", (512, 512), (25, 29, 36))
    draw = ImageDraw.Draw(image)
    draw.text((32, 236), message, fill=(230, 235, 240))
    return image


def normalized_prompt_weights(weight_a, weight_b, weight_c):
    """Clamp the three weights at zero and scale them to sum to one.

    Returns equal thirds when the clamped weights sum to zero.
    """
    weights = [max(0.0, float(weight)) for weight in (weight_a, weight_b, weight_c)]
    total = sum(weights)
    if total <= 0:
        return 1 / 3, 1 / 3, 1 / 3
    return tuple(weight / total for weight in weights)


def compact_prompt(text, fallback):
    """Collapse whitespace in `text` (or `fallback` if empty) and cap it at 52 chars.

    Backticks and double quotes are replaced with single quotes so the result
    can be embedded in Markdown; an ellipsis marks truncated text.
    """
    text = " ".join(str(text or fallback).split())
    text = text.replace("`", "'").replace('"', "'")
    return text[:52] + ("..." if len(text) > 52 else "")


def prompt_equation(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c):
    """Return the blend equation text plus the three normalized weights.

    The prompt arguments are accepted for signature parity with the callers
    but are not used in the equation text.
    """
    weight_a, weight_b, weight_c = normalized_prompt_weights(weight_a, weight_b, weight_c)
    return (
        f"prompt_embedding = {weight_a:.2f} * A + {weight_b:.2f} * B + {weight_c:.2f} * C",
        weight_a,
        weight_b,
        weight_c,
    )


def equation_markdown(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c):
    """Render the current blend equation and the prompt legend as Markdown."""
    equation = prompt_equation(prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c)[0]
    # Two trailing spaces before "\n" are a Markdown hard line break; a single
    # space would let the A/B/C lines flow together into one paragraph.
    return (
        "### Current Equation\n"
        f"`{equation}`\n\n"
        f"**A** = {compact_prompt(prompt_a, 'Prompt A')}  \n"
        f"**B** = {compact_prompt(prompt_b, 'Prompt B')}  \n"
        f"**C** = {compact_prompt(prompt_c, 'Prompt C')}"
    )


@lru_cache(maxsize=2)
def load_pipe(model_id, device_type):
    """Load, configure and cache a Stable Diffusion pipeline.

    Args:
        model_id: Model identifier passed to `from_pretrained`.
        device_type: Device type string (for example "cuda"); a string is used
            so the arguments are hashable for `lru_cache`.

    Returns:
        The pipeline moved to the device, with a DPM-Solver multistep
        scheduler, progress bars disabled and VAE slicing enabled.
    """
    device = torch.device(device_type)
    dtype = model_dtype(device)
    pipe = StableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
    )
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to(device)
    pipe.set_progress_bar_config(disable=True)
    pipe.enable_vae_slicing()
    if device.type == "cuda":
        try:
            pipe.enable_xformers_memory_efficient_attention()
        except Exception as error:
            # Best effort: the pipeline still works without xformers, so record
            # why it was skipped instead of failing or hiding the reason.
            import logging

            logging.getLogger(__name__).info(
                "xformers memory-efficient attention not enabled: %s", error
            )
    return pipe


def encode_prompt(pipe, prompt, negative_prompt, device):
    """Encode `prompt` and `negative_prompt` into text embeddings.

    Uses `pipe.encode_prompt` when the pipeline has it, otherwise falls back
    to `pipe._encode_prompt` and splits its concatenated output.

    Returns:
        A `(prompt_embeds, negative_prompt_embeds)` tuple.
    """
    if hasattr(pipe, "encode_prompt"):
        prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
            prompt=prompt,
            device=device,
            num_images_per_prompt=1,
            do_classifier_free_guidance=True,
            negative_prompt=negative_prompt,
        )
        return prompt_embeds, negative_prompt_embeds
    combined = pipe._encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
        negative_prompt=negative_prompt,
    )
    # The fallback output is split as [negative, prompt] along the first dim.
    negative_prompt_embeds, prompt_embeds = combined.chunk(2)
    return prompt_embeds, negative_prompt_embeds


def cosine_similarity(a, b):
    """Return the cosine similarity of two tensors after flattening, as a float."""
    a = a.detach().float().flatten()
    b = b.detach().float().flatten()
    return float(torch.nn.functional.cosine_similarity(a, b, dim=0).cpu())


def mix_prompt_embeddings(
    pipe,
    device,
    prompt_a,
    prompt_b,
    prompt_c,
    negative_prompt,
    weight_a,
    weight_b,
    weight_c,
    renormalize_prompt,
):
    """Blend three prompt embeddings with normalized weights.

    Args:
        pipe: Pipeline used to encode the prompts.
        device: Device passed through to the encoder.
        prompt_a, prompt_b, prompt_c: Prompt texts to blend.
        negative_prompt: Negative prompt encoded alongside each prompt.
        weight_a, weight_b, weight_c: Raw strengths; normalized before mixing.
        renormalize_prompt: When true (and the mix is non-zero), rescale the
            mixed embedding to the norm of prompt A's embedding.

    Returns:
        `(mixed, negative_embeds, formula, metrics)` where `negative_embeds`
        comes from encoding prompt A, `formula` is a description string and
        `metrics` is a list of `[name, value]` rows.
    """
    emb_a, negative_embeds = encode_prompt(pipe, prompt_a, negative_prompt, device)
    emb_b, _ = encode_prompt(pipe, prompt_b, negative_prompt, device)
    emb_c, _ = encode_prompt(pipe, prompt_c or "", negative_prompt, device)
    formula, weight_a, weight_b, weight_c = prompt_equation(
        prompt_a, prompt_b, prompt_c, weight_a, weight_b, weight_c
    )
    mixed = weight_a * emb_a + weight_b * emb_b + weight_c * emb_c
    # Norms are computed in float32 regardless of the embedding dtype.
    original_norm = emb_a.detach().float().norm()
    mixed_norm = mixed.detach().float().norm()
    if renormalize_prompt and float(mixed_norm.cpu()) > 0:
        mixed = mixed * (original_norm / mixed_norm)
        formula += "; then rescale to A's embedding norm"
    metrics = [
        ["cosine(A, B)", round(cosine_similarity(emb_a, emb_b), 4)],
        ["cosine(A, mixed)", round(cosine_similarity(emb_a, mixed), 4)],
        ["cosine(B, mixed)", round(cosine_similarity(emb_b, mixed), 4)],
        ["cosine(C, mixed)", round(cosine_similarity(emb_c, mixed), 4)],
        ["norm(A)", round(float(original_norm.cpu()), 3)],
        ["norm(mixed)", round(float(mixed.detach().float().norm().cpu()), 3)],
    ]
    return mixed, negative_embeds, formula, metrics


def prepare_latents(pipe, device, height, width, seed_a, seed_b, noise_mix, renormalize_noise):
    """Create the starting latents by linearly mixing two seeded noise tensors.

    Args:
        pipe: Pipeline supplying the UNet channel count, VAE scale factor and
            the scheduler's `init_noise_sigma`.
        device: Target device for the latents.
        height, width: Output image size in pixels.
        seed_a, seed_b: Seeds for the two noise tensors.
        noise_mix: Fraction of seed B's noise in the mix (0 uses only seed A).
        renormalize_noise: When true, shift/scale the mix to zero mean and
            unit standard deviation before the scheduler scaling.

    Returns:
        `(latents, formula, metrics)`.
    """
    channels = int(pipe.unet.config.in_channels)
    latent_shape = (
        1,
        channels,
        height // pipe.vae_scale_factor,
        width // pipe.vae_scale_factor,
    )
    dtype = model_dtype(device)
    noise_a = randn_tensor(latent_shape, seed_a, device, dtype)
    noise_b = randn_tensor(latent_shape, seed_b, device, dtype)
    latents = (1.0 - noise_mix) * noise_a + noise_mix * noise_b
    before_std = float(latents.detach().float().std().cpu())
    if renormalize_noise:
        # The epsilon guards against division by a zero standard deviation.
        latents = (latents - latents.mean()) / (latents.std() + 1e-6)
    after_std = float(latents.detach().float().std().cpu())
    latents = latents * pipe.scheduler.init_noise_sigma
    formula = f"noise = {(1.0 - noise_mix):.2f} * seed A + {noise_mix:.2f} * seed B"
    if renormalize_noise:
        formula += "; then renormalize to unit standard deviation"
    metrics = [
        ["latent shape", str(tuple(latent_shape))],
        ["std before scheduler scale", round(before_std, 4)],
        ["std after optional renorm", round(after_std, 4)],
        ["scheduler init sigma", round(float(pipe.scheduler.init_noise_sigma), 4)],
    ]
    return latents, formula, metrics


def apply_classifier_free_guidance(noise_negative, noise_prompt, guidance_scale):
    """Combine the two noise predictions and return `(guided, formula)`.

    Computes `noise_negative + guidance_scale * (noise_prompt - noise_negative)`.
    """
    delta = noise_prompt - noise_negative
    guided = noise_negative + guidance_scale * delta
    formula = f"guided = negative + {guidance_scale:.2f} * (prompt - negative)"
    return guided, formula


def decode_latents(pipe, latents):
    """Decode latents with the pipeline's VAE into a list of PIL images."""
    latents = latents / pipe.vae.config.scaling_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    # Map the decoder output from [-1, 1] to [0, 1] before converting to uint8.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
    image = (image * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in image]


def checkpoint_indices(num_steps):
    """Return up to four sorted, de-duplicated step indices to snapshot.

    The indices are the first step, roughly one third, roughly two thirds,
    and the last step of a run with `num_steps` steps.
    """
    last = max(0, int(num_steps) - 1)
    return sorted({0, last // 3, (2 * last) // 3, last})


def gpu_duration(*args):
    """Estimate a GPU time budget (seconds) from the arguments of `generate`.

    Reads the last three positional arguments as steps, width and height.
    Returns 90 when they are missing or not convertible to int; otherwise a
    value clamped to [60, 180] that grows with steps and pixel count.
    """
    try:
        steps = int(args[-3])
        width = int(args[-2])
        height = int(args[-1])
    except (IndexError, TypeError, ValueError, OverflowError):
        return 90
    pixel_factor = max(1.0, (width * height) / (512 * 512))
    return min(180, max(60, int(35 + steps * 2.5 * pixel_factor)))


@spaces.GPU(duration=gpu_duration)
@torch.inference_mode()
def generate(
    prompt_a,
    prompt_b,
    prompt_c,
    weight_a,
    weight_b,
    weight_c,
    seed_a,
    seed_b,
    noise_mix,
    negative_prompt,
    guidance_scale,
    num_steps,
    width,
    height,
):
    """Run the blended-prompt denoising loop and collect outputs for the UI.

    Returns:
        `(final_image, snapshots, summary, metrics)`: the final PIL image, a
        list of `(image, caption)` snapshot tuples, a text summary, and a
        list of `[name, value]` metric rows. On a CPU-only machine a
        placeholder image and an explanatory message are returned instead.
    """
    device = current_device()
    if device.type == "cpu":
        # NOTE(review): generation is skipped on CPU even though the message
        # says it can run there - confirm which behaviour is intended.
        return (
            blank_image("GPU recommended."),
            [],
            "No GPU was detected. The app is designed for CUDA or MPS. It can run on CPU, but it may take a very long time.",
            [],
        )

    width = round_to_multiple_of_8(width)
    height = round_to_multiple_of_8(height)
    num_steps = int(num_steps)
    guidance = float(guidance_scale)  # loop-invariant, converted once
    scheduler_name = "DPM++ 2M"
    pipe = load_pipe(DEFAULT_MODEL, device.type)

    prompt_embeds, negative_prompt_embeds, prompt_formula, prompt_metrics = mix_prompt_embeddings(
        pipe,
        device,
        prompt_a or "",
        prompt_b or "",
        prompt_c or "",
        negative_prompt or "",
        float(weight_a),
        float(weight_b),
        float(weight_c),
        True,
    )

    pipe.scheduler.set_timesteps(num_steps, device=device)
    latents, latent_formula, latent_metrics = prepare_latents(
        pipe,
        device,
        height,
        width,
        # An emptied seed field arrives as None; fall back to the UI default.
        _coerce_seed(seed_a, DEFAULT_SEED_A),
        _coerce_seed(seed_b, DEFAULT_SEED_B),
        float(noise_mix),
        True,
    )

    # Negative embeddings first, matching the chunk(2) order used below.
    text_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
    snapshots = []
    save_at = checkpoint_indices(num_steps)
    last_formula = ""
    final_image = None
    timesteps = pipe.scheduler.timesteps
    final_index = len(timesteps) - 1

    for step_index, timestep in enumerate(timesteps):
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, timestep)
        noise_pred = pipe.unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=text_embeds,
            return_dict=False,
        )[0]
        noise_negative, noise_prompt = noise_pred.chunk(2)
        guided, last_formula = apply_classifier_free_guidance(
            noise_negative,
            noise_prompt,
            guidance,
        )
        latents = pipe.scheduler.step(guided, timestep, latents, return_dict=False)[0]

        if step_index in save_at:
            snapshot = decode_latents(pipe, latents)[0]
            snapshots.append((snapshot, f"step {step_index + 1} of {num_steps}"))
            if step_index == final_index:
                # The last snapshot already is the finished image; reuse it
                # rather than running the VAE decoder a second time.
                final_image = snapshot

    if final_image is None:
        final_image = decode_latents(pipe, latents)[0]

    metrics = prompt_metrics + latent_metrics + [
        ["device", device_label(device)],
        ["prompt formula", prompt_formula],
        ["noise formula", latent_formula],
        ["guidance formula", last_formula],
    ]
    summary = (
        f"Prompt blend: {prompt_formula}\n\n"
        f"Starting noise: {latent_formula}\n\n"
        f"Guidance: {last_formula}\n\n"
        f"Steps: {num_steps}; size: {width}x{height}; scheduler: {scheduler_name}\n\n"
        "The sliders blend text embeddings, not pixels. Stable Diffusion starts from noise and uses this blended prompt "
        "to steer each denoising step."
    )

    # Collect garbage first so memory it frees can then be released from the
    # CUDA caching allocator.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    return final_image, snapshots, summary, metrics


def randomize_seeds():
    """Return two fresh random seeds in `[0, MAX_SEED)`."""
    rng = np.random.default_rng()
    return int(rng.integers(0, MAX_SEED)), int(rng.integers(0, MAX_SEED))


def build_app():
    """Build the Gradio Blocks UI, wire its events, and return it."""
    theme = gr.themes.Soft(
        primary_hue="indigo",
        secondary_hue="emerald",
        neutral_hue="slate",
        radius_size="sm",
    )
    css = """
    .snapshot-gallery img { object-fit: contain !important; }
    .code-panel textarea, .code-panel pre { font-size: 13px !important; }
    """
    metric_headers = ["quantity", "value"]

    with gr.Blocks(title=APP_TITLE, theme=theme, css=css) as demo:
        gr.Markdown(
            f"# {APP_TITLE}\n"
            "Blend three ideas, then watch Stable Diffusion turn noise into an image using that blended prompt embedding."
        )
        equation_preview = gr.Markdown(
            equation_markdown(DEFAULT_PROMPT_A, DEFAULT_PROMPT_B, DEFAULT_PROMPT_C, 1, 1, 1)
        )
        # Image size is fixed state rather than a user-facing control.
        width = gr.State(512)
        height = gr.State(512)

        with gr.Row(equal_height=False):
            with gr.Column(scale=1, min_width=320):
                with gr.Group():
                    prompt_a = gr.Textbox(value=DEFAULT_PROMPT_A, label="Prompt A", lines=2)
                    strength_a = gr.Slider(0, 3, value=1, step=0.05, label="Strength A")
                    prompt_b = gr.Textbox(value=DEFAULT_PROMPT_B, label="Prompt B", lines=2)
                    strength_b = gr.Slider(0, 3, value=1, step=0.05, label="Strength B")
                    prompt_c = gr.Textbox(value=DEFAULT_PROMPT_C, label="Prompt C", lines=2)
                    strength_c = gr.Slider(0, 3, value=1, step=0.05, label="Strength C")

                with gr.Accordion("A Few Diffusers Levers", open=True):
                    guidance_scale = gr.Slider(1, 14, value=7.5, step=0.5, label="Prompt guidance")
                    num_steps = gr.Slider(8, 35, value=20, step=1, label="Denoising steps")
                    with gr.Row():
                        seed_a = gr.Number(value=DEFAULT_SEED_A, precision=0, label="Starting noise seed")
                        random_seeds = gr.Button("Random seed")
                    negative_prompt = gr.Textbox(
                        value=DEFAULT_NEGATIVE_PROMPT,
                        label="Things to avoid",
                        lines=1,
                    )

                with gr.Accordion("Extra Noise Mixer", open=False):
                    with gr.Row():
                        seed_b = gr.Number(value=DEFAULT_SEED_B, precision=0, label="Second noise seed")
                        noise_mix = gr.Slider(0, 1, value=0.0, step=0.05, label="Second seed strength")

                generate_button = gr.Button("Generate", variant="primary")

            with gr.Column(scale=1, min_width=420):
                output_image = gr.Image(
                    value=blank_image(),
                    label="Generated image",
                    type="pil",
                    interactive=False,
                )
                summary = gr.Textbox(label="What happened", lines=12, interactive=False)
                with gr.Accordion("Denoising Snapshots", open=False):
                    snapshots = gr.Gallery(
                        label="Decoded latent snapshots",
                        columns=2,
                        height=420,
                        object_fit="contain",
                        elem_classes=["snapshot-gallery"],
                    )
                with gr.Accordion("Embedding Measurements", open=False):
                    metrics = gr.Dataframe(
                        headers=metric_headers,
                        datatype=["str", "str"],
                        label="Embedding and latent measurements",
                        interactive=False,
                    )

        with gr.Accordion("Code Cells", open=False):
            with gr.Row(equal_height=False):
                gr.Code(
                    PROMPT_MATH_CODE,
                    language="python",
                    label="Prompt embedding math",
                    interactive=False,
                    elem_classes=["code-panel"],
                )
                gr.Code(
                    LATENT_MATH_CODE,
                    language="python",
                    label="Latent noise math",
                    interactive=False,
                    elem_classes=["code-panel"],
                )
                gr.Code(
                    GUIDANCE_MATH_CODE,
                    language="python",
                    label="Guidance equation",
                    interactive=False,
                    elem_classes=["code-panel"],
                )

        random_seeds.click(
            randomize_seeds,
            inputs=None,
            outputs=[seed_a, seed_b],
            show_progress="hidden",
        )

        # Any change to a prompt or strength refreshes the equation preview.
        equation_inputs = [prompt_a, prompt_b, prompt_c, strength_a, strength_b, strength_c]
        for equation_input in equation_inputs:
            equation_input.change(
                equation_markdown,
                inputs=equation_inputs,
                outputs=[equation_preview],
                show_progress="hidden",
            )

        # Input order must match the positional parameters of `generate`;
        # `gpu_duration` also relies on steps/width/height being last.
        generate_button.click(
            generate,
            inputs=[
                prompt_a,
                prompt_b,
                prompt_c,
                strength_a,
                strength_b,
                strength_c,
                seed_a,
                seed_b,
                noise_mix,
                negative_prompt,
                guidance_scale,
                num_steps,
                width,
                height,
            ],
            outputs=[output_image, snapshots, summary, metrics],
            show_progress="full",
        )

    return demo


if __name__ == "__main__":
    build_app().queue(max_size=8).launch()