# Provenance (from the Hugging Face Space web page, kept here as a comment so
# the module stays valid Python): uploaded by PraneshJs, "Update app.py",
# commit 181be6c (verified), file size 12.3 kB.
# IMAGE DIFFUSION VISUALIZER β€” ADVANCED
# Visualizes how a (tiny) Stable Diffusion model denoises step by step.
# Model: hf-internal-testing/tiny-stable-diffusion-pipe (small, CPU-safe, for demos)
import gradio as gr
import torch
import numpy as np
from diffusers import DiffusionPipeline
from sklearn.decomposition import PCA
import plotly.graph_objects as go
import plotly.express as px
from PIL import Image
import time
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_ID = "hf-internal-testing/tiny-stable-diffusion-pipe"
PIPE_CACHE = None
# -------------------- MODEL LOADING -------------------- #
def get_pipe():
"""Lazy-load and cache the tiny Stable Diffusion pipeline."""
global PIPE_CACHE
if PIPE_CACHE is not None:
return PIPE_CACHE
pipe = DiffusionPipeline.from_pretrained(MODEL_ID)
pipe.to(DEVICE)
pipe.safety_checker = None # tiny pipe usually doesn't have NSFW issues; keep simple
PIPE_CACHE = pipe
return PIPE_CACHE
# -------------------- CORE UTILS -------------------- #
def decode_latent_to_pil(pipe, latent_np):
"""
Decode a latent (C,H,W) numpy array to a PIL image using the VAE.
Works for intermediate steps too.
"""
vae = pipe.vae
latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
# scaling_factor is used in SD-style VAEs; fallback to standard SD value
scale = getattr(vae.config, "scaling_factor", 0.18215)
with torch.no_grad():
image = vae.decode(latent / scale).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image[0].permute(1, 2, 0).cpu().numpy()
image = (image * 255).astype("uint8")
return Image.fromarray(image)
def compute_pca_over_steps(latents_list):
"""
latents_list: list of (C,H,W) numpy arrays.
Flatten each into a single vector; run PCA across steps.
Returns (S,2) array of 2D coords.
"""
if len(latents_list) == 0:
return None
flat = [x.reshape(-1) for x in latents_list]
mat = np.stack(flat, axis=0) # (steps, dim)
if mat.shape[0] < 2 or mat.shape[1] < 2:
# Not enough data for PCA; return zeros
return np.zeros((mat.shape[0], 2))
try:
pca = PCA(n_components=2)
pts = pca.fit_transform(mat)
return pts
except Exception:
return np.zeros((mat.shape[0], 2))
def compute_norms_over_steps(latents_list):
"""Compute L2 norm of each latent across channels & spatial dims."""
if len(latents_list) == 0:
return []
flat = [x.reshape(-1) for x in latents_list]
norms = [float(np.linalg.norm(v)) for v in flat]
return norms
def explain(simple=True):
if simple:
return (
"πŸ§’ **Simple explanation of what you see:**\n\n"
"1. The model starts with a totally noisy image.\n"
"2. Step by step, it removes noise and shapes the picture.\n"
"3. Your words (the prompt) tell it *what* to draw.\n"
"4. The slider lets you move through these steps:\n"
" - Early steps = mostly noise\n"
" - Later steps = clearer image\n"
)
else:
return (
"πŸ”¬ **Technical explanation:**\n\n"
"- We use a tiny Stable Diffusion-style pipeline.\n"
"- At each timestep `t`, the UNet predicts noise Ξ΅β‚œ for latent `zβ‚œ`.\n"
"- The scheduler updates `zβ‚œ β†’ zβ‚œβ‚‹β‚` using Ξ΅β‚œ.\n"
"- We record the latent after each step and decode it with the VAE.\n"
"- PCA over flattened latents shows the trajectory in latent space.\n"
"- Latent norm vs step shows how the magnitude evolves during denoising.\n"
)
def make_pca_figure(points, current_idx):
"""Make a PCA trajectory plot over steps, highlighting the selected step."""
steps = list(range(len(points)))
fig = go.Figure()
fig.add_trace(go.Scatter(
x=points[:, 0],
y=points[:, 1],
mode="lines+markers",
name="Steps",
text=[f"step {i}" for i in steps]
))
if 0 <= current_idx < len(points):
fig.add_trace(go.Scatter(
x=[points[current_idx, 0]],
y=[points[current_idx, 1]],
mode="markers+text",
text=[f"step {current_idx}"],
textposition="top center",
marker=dict(size=14, color="red"),
name="Current step"
))
fig.update_layout(
title="Latent PCA trajectory over steps",
xaxis_title="PC1",
yaxis_title="PC2",
height=400
)
return fig
def make_norm_figure(norms, current_idx):
"""Plot latent norm vs step, highlighting the current step."""
steps = list(range(len(norms)))
fig = go.Figure()
fig.add_trace(go.Scatter(
x=steps,
y=norms,
mode="lines+markers",
name="Latent norm"
))
if 0 <= current_idx < len(norms):
fig.add_trace(go.Scatter(
x=[steps[current_idx]],
y=[norms[current_idx]],
mode="markers",
marker=dict(size=14, color="red"),
name="Current step"
))
fig.update_layout(
title="Latent L2 norm vs diffusion step",
xaxis_title="Step index (0 = most noisy)",
yaxis_title="β€–latentβ€–β‚‚",
height=400
)
return fig
# -------------------- MAIN ANALYSIS FUNCTION -------------------- #
def run_diffusion_analysis(prompt, num_steps, guidance, seed, simple_mode):
"""
Run the tiny diffusion pipeline, recording latents at each step.
Returns Gradio updates + a state dict.
"""
if not prompt or not prompt.strip():
return (
None, # final image
f"⚠️ Please enter a prompt.",
gr.update(maximum=0, value=0),
None, None, None,
{
"error": "no_prompt"
}
)
pipe = get_pipe()
num_steps = int(num_steps)
guidance = float(guidance)
# Seed handling
if seed is None or seed < 0:
generator = torch.Generator(device=DEVICE)
else:
generator = torch.Generator(device=DEVICE).manual_seed(int(seed))
latents_list = []
timesteps_list = []
def callback(step, timestep, latents):
# latents: (batch, C, H, W)
latents_list.append(latents.detach().cpu().numpy()[0])
timesteps_list.append(int(timestep))
t0 = time.time()
try:
result = pipe(
prompt,
num_inference_steps=num_steps,
guidance_scale=guidance,
generator=generator,
callback=callback,
callback_steps=1,
)
except Exception as e:
return (
None,
f"❌ Model / diffusion error: {e}",
gr.update(maximum=0, value=0),
None, None, None,
{
"error": "diffusion_error",
"details": str(e)
}
)
elapsed = time.time() - t0
if len(latents_list) == 0:
return (
None,
"❌ No latents were collected. Something went wrong inside the pipeline.",
gr.update(maximum=0, value=0),
None, None, None,
{
"error": "no_latents"
}
)
final_image = result.images[0] # PIL
# Compute PCA and norms over steps
pca_points = compute_pca_over_steps(latents_list)
norms = compute_norms_over_steps(latents_list)
# Default step: last (most denoised)
current_idx = len(latents_list) - 1
# Decode image for current step
try:
step_image = decode_latent_to_pil(pipe, latents_list[current_idx])
except Exception:
step_image = None
# Build plots
pca_fig = make_pca_figure(pca_points, current_idx) if pca_points is not None else None
norm_fig = make_norm_figure(norms, current_idx) if norms else None
# Explanation
explanation = explain(simple_mode)
explanation += f"\n\n⏱ **Runtime:** {elapsed:.2f}s β€’ **Steps:** {len(latents_list)}"
# State dict to keep everything for slider updates
state = {
"prompt": prompt,
"num_steps": num_steps,
"guidance": guidance,
"seed": seed,
"latents": latents_list,
"timesteps": timesteps_list,
"pca_points": pca_points,
"norms": norms
}
step_slider_update = gr.update(maximum=len(latents_list)-1, value=current_idx)
return (
final_image,
explanation,
step_slider_update,
step_image,
pca_fig,
norm_fig,
state
)
def update_step_view(state, step_idx):
"""
When the user moves the step slider, update:
- the decoded image at that step
- the PCA plot (highlight current)
- the norm plot (highlight current)
"""
if not state or "latents" not in state:
return gr.update(value=None), gr.update(value=None), gr.update(value=None)
latents_list = state["latents"]
pca_points = state["pca_points"]
norms = state["norms"]
if len(latents_list) == 0:
return gr.update(value=None), gr.update(value=None), gr.update(value=None)
step_idx = int(step_idx)
step_idx = max(0, min(step_idx, len(latents_list) - 1))
pipe = get_pipe()
# Decode image at this step
try:
step_image = decode_latent_to_pil(pipe, latents_list[step_idx])
except Exception:
step_image = None
# Update PCA & norm plots
pca_fig = make_pca_figure(pca_points, step_idx) if pca_points is not None else None
norm_fig = make_norm_figure(norms, step_idx) if norms else None
return gr.update(value=step_image), gr.update(value=pca_fig), gr.update(value=norm_fig)
# -------------------- GRADIO UI -------------------- #
with gr.Blocks(title="Diffusion Visualizer β€” Noise to Image") as demo:
gr.Markdown("# 🧠 Image Diffusion Visualizer (Advanced)")
gr.Markdown(
"See how a tiny Stable Diffusion model turns **pure noise** into an image "
"step by step. Use the slider to move through the diffusion process."
)
with gr.Row():
with gr.Column(scale=2):
prompt_box = gr.Textbox(
label="Prompt",
value="a small house in the forest, digital art",
lines=3
)
num_steps_slider = gr.Slider(
minimum=5, maximum=50, value=20, step=1,
label="Number of diffusion steps"
)
guidance_slider = gr.Slider(
minimum=1.0, maximum=10.0, value=7.5, step=0.5,
label="Guidance scale (higher = follow prompt more)"
)
seed_box = gr.Number(
label="Seed (leave -1 for random)",
value=-1,
precision=0
)
simple_mode_chk = gr.Checkbox(
label="Explain in simple terms (for kids/elders)",
value=True
)
run_btn = gr.Button("Generate & Analyze", variant="primary")
with gr.Column(scale=2):
final_image = gr.Image(label="Final generated image")
explanation_md = gr.Markdown(label="Explanation")
gr.Markdown("### πŸ” Explore the denoising process")
step_slider = gr.Slider(
minimum=0, maximum=0, value=0, step=1,
label="View step (0 = early, noisy β€’ max = late, clear)"
)
with gr.Row():
with gr.Column():
step_image = gr.Image(label="Image at this diffusion step")
with gr.Column():
pca_plot = gr.Plot(label="Latent PCA trajectory")
with gr.Column():
norm_plot = gr.Plot(label="Latent norm vs step")
state = gr.State()
# Wire run button
run_btn.click(
run_diffusion_analysis,
inputs=[prompt_box, num_steps_slider, guidance_slider, seed_box, simple_mode_chk],
outputs=[final_image, explanation_md, step_slider, step_image, pca_plot, norm_plot, state]
)
# Wire slider change
step_slider.change(
update_step_view,
inputs=[state, step_slider],
outputs=[step_image, pca_plot, norm_plot]
)
demo.launch()