# IMAGE DIFFUSION VISUALIZER — ADVANCED
# Visualizes how a (tiny) Stable Diffusion model denoises step by step.
# Model: hf-internal-testing/tiny-stable-diffusion-pipe (small, CPU-safe, for demos)

import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import plotly.express as px  # noqa: F401 -- currently unused; kept to avoid breaking interactive use
from PIL import Image
import time

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-pipe"
PIPE_CACHE = None  # lazily-populated process-wide singleton; see get_pipe()


# -------------------- MODEL LOADING -------------------- #

def get_pipe():
    """Lazy-load and cache the tiny Stable Diffusion pipeline.

    The (slow) ``from_pretrained`` download/load happens at most once per
    process; subsequent calls return the cached pipeline.
    """
    global PIPE_CACHE
    if PIPE_CACHE is not None:
        return PIPE_CACHE
    pipe = DiffusionPipeline.from_pretrained(MODEL_ID)
    pipe.to(DEVICE)
    pipe.safety_checker = None  # tiny pipe usually doesn't have NSFW issues; keep simple
    PIPE_CACHE = pipe
    return PIPE_CACHE


# -------------------- CORE UTILS -------------------- #

def decode_latent_to_pil(pipe, latent_np):
    """Decode a latent (C,H,W) numpy array to a PIL image using the VAE.

    Works for intermediate steps too, which is what lets the UI show the
    partially-denoised image at any slider position.
    """
    vae = pipe.vae
    latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
    # scaling_factor is used in SD-style VAEs; fall back to the standard SD value.
    scale = getattr(vae.config, "scaling_factor", 0.18215)
    with torch.no_grad():
        image = vae.decode(latent / scale).sample
    # Map from the model's [-1, 1] range to [0, 255] uint8 pixels.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image[0].permute(1, 2, 0).cpu().numpy()
    image = (image * 255).astype("uint8")
    return Image.fromarray(image)


def compute_pca_over_steps(latents_list):
    """Project per-step latents onto their first two principal components.

    latents_list: list of (C,H,W) numpy arrays, one per diffusion step.
    Flatten each into a single vector; run PCA across steps.
    Returns an (S,2) array of 2D coords, ``None`` for empty input, or an
    all-zero array when PCA is not applicable or fails (the plot is
    cosmetic, so we degrade rather than raise).
    """
    if len(latents_list) == 0:
        return None
    flat = [x.reshape(-1) for x in latents_list]
    mat = np.stack(flat, axis=0)  # (steps, dim)
    if mat.shape[0] < 2 or mat.shape[1] < 2:
        # Not enough data for a 2-component PCA; return zeros
        return np.zeros((mat.shape[0], 2))
    try:
        pca = PCA(n_components=2)
        return pca.fit_transform(mat)
    except Exception:
        # Best-effort visualization: never let a PCA failure kill the run.
        return np.zeros((mat.shape[0], 2))


def compute_norms_over_steps(latents_list):
    """Compute L2 norm of each latent across channels & spatial dims."""
    # Comprehension handles the empty list naturally (returns []).
    return [float(np.linalg.norm(x.reshape(-1))) for x in latents_list]


def explain(simple=True):
    """Return a Markdown explanation of the visualization.

    simple=True yields the plain-language version; False the technical one.
    """
    if simple:
        return (
            "🧒 **Simple explanation of what you see:**\n\n"
            "1. The model starts with a totally noisy image.\n"
            "2. Step by step, it removes noise and shapes the picture.\n"
            "3. Your words (the prompt) tell it *what* to draw.\n"
            "4. The slider lets you move through these steps:\n"
            "   - Early steps = mostly noise\n"
            "   - Later steps = clearer image\n"
        )
    else:
        return (
            "🔬 **Technical explanation:**\n\n"
            "- We use a tiny Stable Diffusion-style pipeline.\n"
            "- At each timestep `t`, the UNet predicts noise εₜ for latent `zₜ`.\n"
            "- The scheduler updates `zₜ → zₜ₋₁` using εₜ.\n"
            "- We record the latent after each step and decode it with the VAE.\n"
            "- PCA over flattened latents shows the trajectory in latent space.\n"
            "- Latent norm vs step shows how the magnitude evolves during denoising.\n"
        )


def make_pca_figure(points, current_idx):
    """Make a PCA trajectory plot over steps, highlighting the selected step."""
    steps = list(range(len(points)))
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="lines+markers",
        name="Steps",
        text=[f"step {i}" for i in steps],
    ))
    if 0 <= current_idx < len(points):
        # Second trace: a single red marker on top of the trajectory.
        fig.add_trace(go.Scatter(
            x=[points[current_idx, 0]],
            y=[points[current_idx, 1]],
            mode="markers+text",
            text=[f"step {current_idx}"],
            textposition="top center",
            marker=dict(size=14, color="red"),
            name="Current step",
        ))
    fig.update_layout(
        title="Latent PCA trajectory over steps",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=400,
    )
    return fig


def make_norm_figure(norms, current_idx):
    """Plot latent norm vs step, highlighting the current step."""
    steps = list(range(len(norms)))
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=steps,
        y=norms,
        mode="lines+markers",
        name="Latent norm",
    ))
    if 0 <= current_idx < len(norms):
        fig.add_trace(go.Scatter(
            x=[steps[current_idx]],
            y=[norms[current_idx]],
            mode="markers",
            marker=dict(size=14, color="red"),
            name="Current step",
        ))
    fig.update_layout(
        title="Latent L2 norm vs diffusion step",
        xaxis_title="Step index (0 = most noisy)",
        yaxis_title="‖latent‖₂",
        height=400,
    )
    return fig


# -------------------- MAIN ANALYSIS FUNCTION -------------------- #

def run_diffusion_analysis(prompt, num_steps, guidance, seed, simple_mode):
    """Run the tiny diffusion pipeline, recording latents at each step.

    Parameters mirror the UI controls. Returns a 7-tuple of Gradio outputs:
    (final image, explanation markdown, step-slider update, per-step image,
    PCA figure, norm figure, state dict). On failure, image/figure slots are
    None and the state dict carries an "error" key describing what happened.
    """
    if not prompt or not prompt.strip():
        return (
            None,  # final image
            "⚠️ Please enter a prompt.",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {"error": "no_prompt"},
        )

    pipe = get_pipe()
    num_steps = int(num_steps)
    guidance = float(guidance)

    # Seed handling: a negative or absent seed means "unseeded / random".
    if seed is None or seed < 0:
        generator = torch.Generator(device=DEVICE)
    else:
        generator = torch.Generator(device=DEVICE).manual_seed(int(seed))

    latents_list = []
    timesteps_list = []

    def callback(step, timestep, latents):
        # latents: (batch, C, H, W) -- keep only the first batch element.
        latents_list.append(latents.detach().cpu().numpy()[0])
        timesteps_list.append(int(timestep))

    t0 = time.time()
    try:
        # NOTE(review): `callback`/`callback_steps` are deprecated in newer
        # diffusers releases in favor of `callback_on_step_end` -- confirm
        # against the pinned diffusers version before upgrading.
        result = pipe(
            prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )
    except Exception as e:
        return (
            None,
            f"❌ Model / diffusion error: {e}",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {"error": "diffusion_error", "details": str(e)},
        )
    elapsed = time.time() - t0

    if len(latents_list) == 0:
        return (
            None,
            "❌ No latents were collected. Something went wrong inside the pipeline.",
            gr.update(maximum=0, value=0),
            None,
            None,
            None,
            {"error": "no_latents"},
        )

    final_image = result.images[0]  # PIL

    # Compute PCA and norms over steps
    pca_points = compute_pca_over_steps(latents_list)
    norms = compute_norms_over_steps(latents_list)

    # Default step: last (most denoised)
    current_idx = len(latents_list) - 1

    # Decode image for current step (best-effort; plots still render on failure).
    try:
        step_image = decode_latent_to_pil(pipe, latents_list[current_idx])
    except Exception:
        step_image = None

    # Build plots
    pca_fig = make_pca_figure(pca_points, current_idx) if pca_points is not None else None
    norm_fig = make_norm_figure(norms, current_idx) if norms else None

    # Explanation text plus runtime stats.
    explanation = explain(simple_mode)
    explanation += f"\n\n⏱ **Runtime:** {elapsed:.2f}s • **Steps:** {len(latents_list)}"

    # State dict keeps everything the slider callback needs later.
    state = {
        "prompt": prompt,
        "num_steps": num_steps,
        "guidance": guidance,
        "seed": seed,
        "latents": latents_list,
        "timesteps": timesteps_list,
        "pca_points": pca_points,
        "norms": norms,
    }

    step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)

    return (
        final_image,
        explanation,
        step_slider_update,
        step_image,
        pca_fig,
        norm_fig,
        state,
    )


def update_step_view(state, step_idx):
    """When the user moves the step slider, update:

    - the decoded image at that step
    - the PCA plot (highlight current)
    - the norm plot (highlight current)
    """
    if not state or "latents" not in state:
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    latents_list = state["latents"]
    pca_points = state["pca_points"]
    norms = state["norms"]

    if len(latents_list) == 0:
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    # Clamp the slider value into the valid step range.
    step_idx = int(step_idx)
    step_idx = max(0, min(step_idx, len(latents_list) - 1))

    pipe = get_pipe()

    # Decode image at this step (best-effort).
    try:
        step_image = decode_latent_to_pil(pipe, latents_list[step_idx])
    except Exception:
        step_image = None

    # Re-render plots with the new highlighted step.
    pca_fig = make_pca_figure(pca_points, step_idx) if pca_points is not None else None
    norm_fig = make_norm_figure(norms, step_idx) if norms else None

    return gr.update(value=step_image), gr.update(value=pca_fig), gr.update(value=norm_fig)


# -------------------- GRADIO UI -------------------- #

with gr.Blocks(title="Diffusion Visualizer — Noise to Image", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Image Diffusion Visualizer (Advanced)")
    gr.Markdown(
        "See how a tiny Stable Diffusion model turns **pure noise** into an image "
        "step by step. Use the slider to move through the diffusion process."
    )

    with gr.Row():
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Prompt",
                value="a small house in the forest, digital art",
                lines=3,
            )
            num_steps_slider = gr.Slider(
                minimum=5, maximum=50, value=20, step=1,
                label="Number of diffusion steps",
            )
            guidance_slider = gr.Slider(
                minimum=1.0, maximum=10.0, value=7.5, step=0.5,
                label="Guidance scale (higher = follow prompt more)",
            )
            seed_box = gr.Number(
                label="Seed (leave -1 for random)",
                value=-1,
                precision=0,
            )
            simple_mode_chk = gr.Checkbox(
                label="Explain in simple terms (for kids/elders)",
                value=True,
            )
            run_btn = gr.Button("Generate & Analyze", variant="primary")

        with gr.Column(scale=2):
            final_image = gr.Image(label="Final generated image")
            explanation_md = gr.Markdown(label="Explanation")

    gr.Markdown("### 🔍 Explore the denoising process")
    step_slider = gr.Slider(
        minimum=0, maximum=0, value=0, step=1,
        label="View step (0 = early, noisy • max = late, clear)",
    )

    with gr.Row():
        with gr.Column():
            step_image = gr.Image(label="Image at this diffusion step")
        with gr.Column():
            pca_plot = gr.Plot(label="Latent PCA trajectory")
        with gr.Column():
            norm_plot = gr.Plot(label="Latent norm vs step")

    state = gr.State()

    # Wire run button
    run_btn.click(
        run_diffusion_analysis,
        inputs=[prompt_box, num_steps_slider, guidance_slider, seed_box, simple_mode_chk],
        outputs=[final_image, explanation_md, step_slider, step_image,
                 pca_plot, norm_plot, state],
    )

    # Wire slider change
    step_slider.change(
        update_step_view,
        inputs=[state, step_slider],
        outputs=[step_image, pca_plot, norm_plot],
    )

demo.launch()