Spaces:
Sleeping
Sleeping
# IMAGE DIFFUSION VISUALIZER — ADVANCED
# Visualizes how a (tiny) Stable Diffusion model denoises step by step.
# Model: hf-internal-testing/tiny-stable-diffusion-pipe (small, CPU-safe, for demos)
import logging
import time

import gradio as gr
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import torch
from diffusers import DiffusionPipeline
from PIL import Image
from sklearn.decomposition import PCA
# Hugging Face model id of the tiny demo pipeline loaded by get_pipe().
MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-pipe"

# Process-wide slot for the loaded pipeline; stays None until get_pipe() fills it.
PIPE_CACHE = None

# Torch device string: use CUDA when torch reports it as available, else CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# -------------------- MODEL LOADING -------------------- #
def get_pipe():
    """Return the shared diffusion pipeline, creating it on the first call.

    The pipeline is loaded from ``MODEL_ID``, moved to ``DEVICE``, has its
    ``safety_checker`` attribute cleared, and is then stored in the
    module-level ``PIPE_CACHE`` so later calls reuse the same object.
    """
    global PIPE_CACHE
    if PIPE_CACHE is None:
        loaded = DiffusionPipeline.from_pretrained(MODEL_ID)
        loaded.to(DEVICE)
        # Keep the demo simple: no safety-checker pass on the tiny test pipe.
        loaded.safety_checker = None
        PIPE_CACHE = loaded
    return PIPE_CACHE
# -------------------- CORE UTILS -------------------- #
def decode_latent_to_pil(pipe, latent_np):
    """Decode a single latent into a PIL image using the pipeline's VAE.

    Works for intermediate denoising steps as well as the final latent.

    Args:
        pipe: Pipeline object exposing ``vae`` (with ``config`` and ``decode``).
        latent_np: NumPy array holding one latent without a batch axis
            (callers in this file pass (C, H, W) arrays).

    Returns:
        A ``PIL.Image`` built from the decoded 8-bit pixel array.
    """
    vae = pipe.vae
    # Add the batch axis, then match the VAE's dtype (when it exposes one) so
    # a reduced-precision VAE does not reject a latent of a different dtype.
    latent = torch.from_numpy(latent_np).unsqueeze(0)
    latent = latent.to(device=DEVICE, dtype=getattr(vae, "dtype", latent.dtype))
    # scaling_factor is used in SD-style VAEs; fall back to the standard SD value.
    scale = getattr(vae.config, "scaling_factor", 0.18215)
    with torch.no_grad():
        decoded = vae.decode(latent / scale).sample
    # Shift/scale the decoder output and clamp it into [0, 1].
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    # float() guarantees a dtype NumPy can represent (e.g. when decoding in bf16).
    pixels = decoded[0].permute(1, 2, 0).cpu().float().numpy()
    # Round before the uint8 cast; a bare astype() truncates and biases the
    # preview darker than the pipeline's own image conversion.
    pixels = (pixels * 255).round().astype("uint8")
    return Image.fromarray(pixels)
def compute_pca_over_steps(latents_list):
    """Project the per-step latents onto their first two principal components.

    Each latent in ``latents_list`` is flattened to one row, the rows are
    stacked into a (steps, dim) matrix, and a 2-component PCA is fitted on it.

    Returns:
        ``None`` when ``latents_list`` is empty; otherwise a (steps, 2) array.
        The array is all zeros when there are fewer than two rows or columns,
        or when the PCA fit raises.
    """
    if not len(latents_list):
        return None

    mat = np.stack([latent.reshape(-1) for latent in latents_list], axis=0)
    n_rows, n_cols = mat.shape

    # PCA with two components needs at least two samples and two features.
    if min(n_rows, n_cols) < 2:
        return np.zeros((n_rows, 2))

    try:
        return PCA(n_components=2).fit_transform(mat)
    except Exception:
        # Best effort: a failed fit degrades to a flat trajectory, not a crash.
        return np.zeros((n_rows, 2))
def compute_norms_over_steps(latents_list):
    """Return the L2 norm of every latent, taken over all of its elements.

    Each latent is flattened first, so the norm spans channels and spatial
    dimensions together. An empty input yields an empty list.
    """
    if not len(latents_list):
        return []
    return [
        float(np.linalg.norm(latent.reshape(-1)))
        for latent in latents_list
    ]
def explain(simple=True):
    """Return the Markdown explanation shown next to the generated image.

    Args:
        simple: When truthy, return the plain-language version; otherwise
            return the technical version.

    Returns:
        A Markdown-formatted string.
    """
    if simple:
        return (
            "🧠 **Simple explanation of what you see:**\n\n"
            "1. The model starts with a totally noisy image.\n"
            "2. Step by step, it removes noise and shapes the picture.\n"
            "3. Your words (the prompt) tell it *what* to draw.\n"
            "4. The slider lets you move through these steps:\n"
            "   - Early steps = mostly noise\n"
            "   - Later steps = clearer image\n"
        )
    return (
        "🔬 **Technical explanation:**\n\n"
        "- We use a tiny Stable Diffusion-style pipeline.\n"
        "- At each timestep `t`, the UNet predicts noise εₜ for latent `zₜ`.\n"
        "- The scheduler updates `zₜ → zₜ₋₁` using εₜ.\n"
        "- We record the latent after each step and decode it with the VAE.\n"
        "- PCA over flattened latents shows the trajectory in latent space.\n"
        "- Latent norm vs step shows how the magnitude evolves during denoising.\n"
    )
def make_pca_figure(points, current_idx):
    """Build the PCA trajectory figure and mark the selected step.

    Args:
        points: 2-D array indexed as ``points[:, 0]`` / ``points[:, 1]``,
            one row per step.
        current_idx: Index of the step to highlight; ignored when it falls
            outside ``[0, len(points))``.

    Returns:
        A Plotly ``Figure`` with the trajectory trace and, when the index is
        valid, a second trace holding the highlighted point.
    """
    n_points = len(points)

    trajectory = go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="lines+markers",
        name="Steps",
        text=[f"step {i}" for i in range(n_points)]
    )
    fig = go.Figure()
    fig.add_trace(trajectory)

    if 0 <= current_idx < n_points:
        highlight = go.Scatter(
            x=[points[current_idx, 0]],
            y=[points[current_idx, 1]],
            mode="markers+text",
            text=[f"step {current_idx}"],
            textposition="top center",
            marker=dict(size=14, color="red"),
            name="Current step"
        )
        fig.add_trace(highlight)

    fig.update_layout(
        title="Latent PCA trajectory over steps",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=400
    )
    return fig
def make_norm_figure(norms, current_idx):
    """Plot the latent L2 norm against the step index and mark one step.

    Args:
        norms: Sequence of per-step norm values.
        current_idx: Index of the step to highlight; ignored when it falls
            outside ``[0, len(norms))``.

    Returns:
        A Plotly ``Figure`` with the norm curve and, when the index is valid,
        a second trace holding the highlighted point.
    """
    steps = list(range(len(norms)))
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=steps,
        y=norms,
        mode="lines+markers",
        name="Latent norm",
    ))
    if 0 <= current_idx < len(norms):
        fig.add_trace(go.Scatter(
            x=[steps[current_idx]],
            y=[norms[current_idx]],
            mode="markers",
            marker=dict(size=14, color="red"),
            name="Current step",
        ))
    fig.update_layout(
        title="Latent L2 norm vs diffusion step",
        xaxis_title="Step index (0 = most noisy)",
        # Repaired from mis-decoded text: double bars with a subscript 2.
        yaxis_title="‖latent‖₂",
        height=400,
    )
    return fig
# -------------------- MAIN ANALYSIS FUNCTION -------------------- #
def _empty_analysis_outputs(message, error, **extra):
    """Build the 7-output tuple used when there is no analysis to show."""
    return (
        None,  # final image
        message,
        gr.update(maximum=0, value=0),
        None, None, None,
        {"error": error, **extra},
    )


def run_diffusion_analysis(prompt, num_steps, guidance, seed, simple_mode):
    """Run the pipeline once, recording the latent after every step.

    Args:
        prompt: Text prompt; a blank prompt short-circuits with a warning.
        num_steps: Number of inference steps (coerced with ``int``).
        guidance: Guidance scale (coerced with ``float``).
        seed: Seed for the generator; ``None`` or a negative value requests
            a fresh random seed.
        simple_mode: Passed to ``explain`` to pick the explanation text.

    Returns:
        A 7-tuple matching the Gradio outputs wired to the run button:
        (final image, explanation markdown, step-slider update, image decoded
        at the last step, PCA figure, norm figure, state dict). On a blank
        prompt, a pipeline error, or no collected latents, the image and
        figure slots are ``None`` and the state dict carries an ``"error"``
        key instead of the analysis data.
    """
    if not prompt or not prompt.strip():
        return _empty_analysis_outputs("⚠️ Please enter a prompt.", "no_prompt")

    pipe = get_pipe()
    num_steps = int(num_steps)
    guidance = float(guidance)

    # Seed handling. A freshly constructed torch.Generator always starts from
    # the same default seed, so the "random" path must reseed it explicitly;
    # otherwise every unseeded run would produce the same image.
    generator = torch.Generator(device=DEVICE)
    if seed is None or seed < 0:
        seed_used = generator.seed()
    else:
        seed_used = int(seed)
        generator.manual_seed(seed_used)

    latents_list = []
    timesteps_list = []

    def callback(step, timestep, latents):
        # latents: (batch, C, H, W); keep only the first batch element.
        latents_list.append(latents.detach().cpu().numpy()[0])
        timesteps_list.append(int(timestep))

    # NOTE(review): `callback` / `callback_steps` are the legacy per-step hook
    # arguments; newer diffusers releases favor `callback_on_step_end`.
    # Confirm against the diffusers version this app is pinned to.
    t0 = time.perf_counter()
    try:
        result = pipe(
            prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )
    except Exception as e:
        return _empty_analysis_outputs(
            f"❌ Model / diffusion error: {e}",
            "diffusion_error",
            details=str(e),
        )
    elapsed = time.perf_counter() - t0

    if not latents_list:
        return _empty_analysis_outputs(
            "❌ No latents were collected. Something went wrong inside the pipeline.",
            "no_latents",
        )

    final_image = result.images[0]  # PIL

    # Compute PCA and norms over steps.
    pca_points = compute_pca_over_steps(latents_list)
    norms = compute_norms_over_steps(latents_list)

    # Default step: last (most denoised).
    current_idx = len(latents_list) - 1

    # Decode the image for the current step. This stays best effort (the final
    # image is still shown), but the failure is logged instead of vanishing.
    try:
        step_image = decode_latent_to_pil(pipe, latents_list[current_idx])
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not decode latent at step %d", current_idx, exc_info=True
        )
        step_image = None

    # Build plots.
    pca_fig = make_pca_figure(pca_points, current_idx) if pca_points is not None else None
    norm_fig = make_norm_figure(norms, current_idx) if norms else None

    # Explanation plus a short runtime summary.
    explanation = explain(simple_mode)
    explanation += f"\n\n⏱ **Runtime:** {elapsed:.2f}s • **Steps:** {len(latents_list)}"

    # State dict keeps everything the step slider needs for later updates.
    state = {
        "prompt": prompt,
        "num_steps": num_steps,
        "guidance": guidance,
        "seed": seed,
        "seed_used": seed_used,
        "latents": latents_list,
        "timesteps": timesteps_list,
        "pca_points": pca_points,
        "norms": norms,
    }
    step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)

    return (
        final_image,
        explanation,
        step_slider_update,
        step_image,
        pca_fig,
        norm_fig,
        state,
    )
def update_step_view(state, step_idx):
    """Refresh the per-step outputs after the step slider moves.

    Args:
        state: State dict produced by ``run_diffusion_analysis``; may be
            empty/``None`` or an error dict without a ``"latents"`` key.
        step_idx: Requested step index; coerced to ``int`` and clamped to the
            range of recorded latents.

    Returns:
        Three Gradio updates, in order: decoded step image, PCA figure with
        the step highlighted, norm figure with the step highlighted. All
        three carry ``None`` when there is nothing to show.
    """
    if not state or "latents" not in state:
        return tuple(gr.update(value=None) for _ in range(3))

    recorded = state["latents"]
    trajectory = state["pca_points"]
    norm_values = state["norms"]

    if len(recorded) == 0:
        return tuple(gr.update(value=None) for _ in range(3))

    # Clamp the requested index into the valid range.
    last_idx = len(recorded) - 1
    idx = max(0, min(int(step_idx), last_idx))

    pipe = get_pipe()

    # Decoding is best effort: a failure blanks the image but keeps the plots.
    try:
        decoded = decode_latent_to_pil(pipe, recorded[idx])
    except Exception:
        decoded = None

    if trajectory is not None:
        pca_fig = make_pca_figure(trajectory, idx)
    else:
        pca_fig = None

    if norm_values:
        norm_fig = make_norm_figure(norm_values, idx)
    else:
        norm_fig = None

    return (
        gr.update(value=decoded),
        gr.update(value=pca_fig),
        gr.update(value=norm_fig),
    )
# -------------------- GRADIO UI -------------------- #
# Layout: controls and results on top, then the step explorer (slider, decoded
# image, PCA plot, norm plot). User-facing text below was repaired from
# mis-decoded characters; the magnifier emoji in the "Explore" heading is a
# best guess at the original glyph.
with gr.Blocks(title="Diffusion Visualizer — Noise to Image", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Image Diffusion Visualizer (Advanced)")
    gr.Markdown(
        "See how a tiny Stable Diffusion model turns **pure noise** into an image "
        "step by step. Use the slider to move through the diffusion process."
    )

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=2):
            prompt_box = gr.Textbox(
                label="Prompt",
                value="a small house in the forest, digital art",
                lines=3,
            )
            num_steps_slider = gr.Slider(
                minimum=5, maximum=50, value=20, step=1,
                label="Number of diffusion steps",
            )
            guidance_slider = gr.Slider(
                minimum=1.0, maximum=10.0, value=7.5, step=0.5,
                label="Guidance scale (higher = follow prompt more)",
            )
            seed_box = gr.Number(
                label="Seed (leave -1 for random)",
                value=-1,
                precision=0,
            )
            simple_mode_chk = gr.Checkbox(
                label="Explain in simple terms (for kids/elders)",
                value=True,
            )
            run_btn = gr.Button("Generate & Analyze", variant="primary")

        # Right column: final result and explanation text.
        with gr.Column(scale=2):
            final_image = gr.Image(label="Final generated image")
            explanation_md = gr.Markdown(label="Explanation")

    gr.Markdown("### 🔍 Explore the denoising process")

    # The range starts collapsed at 0..0; run_diffusion_analysis widens it to
    # the number of recorded steps through its slider update.
    step_slider = gr.Slider(
        minimum=0, maximum=0, value=0, step=1,
        label="View step (0 = early, noisy • max = late, clear)",
    )

    with gr.Row():
        with gr.Column():
            step_image = gr.Image(label="Image at this diffusion step")
        with gr.Column():
            pca_plot = gr.Plot(label="Latent PCA trajectory")
        with gr.Column():
            norm_plot = gr.Plot(label="Latent norm vs step")

    # Holds the dict returned by run_diffusion_analysis between events.
    state = gr.State()

    # Wire run button.
    run_btn.click(
        run_diffusion_analysis,
        inputs=[prompt_box, num_steps_slider, guidance_slider, seed_box, simple_mode_chk],
        outputs=[final_image, explanation_md, step_slider, step_image, pca_plot, norm_plot, state],
    )

    # Wire slider change.
    step_slider.change(
        update_step_view,
        inputs=[state, step_slider],
        outputs=[step_image, pca_plot, norm_plot],
    )

demo.launch()