Spaces:
Sleeping
Sleeping
| # ========================================================== | |
| # Stable Diffusion v1-4 — CPU Diffusion Visualizer (256x256) | |
| # - Runs on HF CPU | |
| # - Real images (not blurry) | |
| # - Step-by-step latents | |
| # - PCA trajectory + latent norm plots | |
| # ========================================================== | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from diffusers import StableDiffusionPipeline, DDIMScheduler | |
| from sklearn.decomposition import PCA | |
| import plotly.graph_objects as go | |
| from PIL import Image | |
| import time | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
# ------------------- CPU SETTINGS -------------------
# Hub identifier of the checkpoint that get_pipe() loads.
MODEL_ID = "CompVis/stable-diffusion-v1-4"
# Everything (pipeline, generator, decoded latents) runs on the CPU.
DEVICE = "cpu"
# Filled in lazily by get_pipe(); stays None until the first load.
PIPE_CACHE = None
# MKLDNN sometimes causes weird matmul errors with SD on some CPUs,
# so it is switched off to be safe.
torch.backends.mkldnn.enabled = False
| # ------------------- LOAD SD MODEL ------------------- | |
def get_pipe():
    """
    Return the shared Stable Diffusion v1-4 pipeline, loading it on first use.

    The pipeline is built from MODEL_ID in float32 with the safety checker
    disabled, switched to a DDIM scheduler, moved to DEVICE and then stored
    in the module-level PIPE_CACHE so later calls reuse the same object.
    """
    global PIPE_CACHE

    if PIPE_CACHE is None:
        loaded = StableDiffusionPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float32,
            # Passing None (plus the flag below) turns the safety checker off.
            safety_checker=None,
            requires_safety_checker=False,
        )
        # DDIM gives clear, predictable timesteps for the visualization.
        loaded.scheduler = DDIMScheduler.from_config(loaded.scheduler.config)
        loaded.to(DEVICE)
        PIPE_CACHE = loaded

    return PIPE_CACHE
| # ------------------- PCA + NORM ------------------- | |
def compute_pca(latents):
    """
    Project a sequence of latents onto their first two principal components.

    latents: list of (C,H,W) numpy arrays, one per diffusion step.
    Returns an (N, 2) array with one 2D point per step. The result is all
    zeros when there are fewer than two latents or when the PCA fit fails.
    """
    if not latents:
        return np.zeros((0, 2))

    # One row per step, each row being the flattened latent.
    matrix = np.stack([lat.flatten() for lat in latents])
    n_steps = matrix.shape[0]
    fallback = np.zeros((n_steps, 2))

    # A single sample has no variance to decompose.
    if n_steps < 2:
        return fallback

    try:
        return PCA(n_components=2).fit_transform(matrix)
    except Exception:
        # Best effort: the plot degrades to a single point at the origin.
        return fallback
def compute_norm(latents):
    """
    Return the L2 norm of every latent, taken over all of its dimensions.

    latents: list of numpy arrays. An empty (or missing) list gives [].
    Each norm is returned as a plain Python float.
    """
    def _l2(arr):
        # Flatten first so the norm covers every element of the latent.
        return float(np.linalg.norm(arr.flatten()))

    return [_l2(latent) for latent in latents] if latents else []
| # ------------------- LATENT DECODER ------------------- | |
def decode_latent(pipe, latent_np):
    """
    Decode a single latent (C,H,W) numpy array into an RGB PIL image.

    pipe: the Stable Diffusion pipeline whose VAE is used for decoding.
    latent_np: one latent as recorded during a run (no batch dimension).

    Returns a PIL image (256x256 for the latents produced by this app).
    """
    # Add the batch dimension the VAE expects and move to the run device.
    latent = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
    scale = pipe.vae.config.scaling_factor

    with torch.no_grad():
        # Undo the latent scaling before decoding.
        image = pipe.vae.decode(latent / scale).sample

    # Map the decoder output from [-1, 1] to [0, 1].
    image = (image / 2 + 0.5).clamp(0, 1)
    pixels = image[0].permute(1, 2, 0).cpu().numpy() * 255
    # Round before the uint8 cast: a bare astype() truncates, which biases
    # every pixel darker and makes the last step differ from the pipeline's
    # own final image (diffusers rounds when it converts to PIL).
    np_img = pixels.round().astype("uint8")
    return Image.fromarray(np_img)
| # ------------------- MAIN DIFFUSION RUN ------------------- | |
def run_diffusion(prompt, steps, guidance, seed, simple):
    """
    Run SD v1-4 at 256x256, capturing latents at EVERY step via callback.

    prompt:   text prompt; blank prompts are rejected with a warning.
    steps:    number of inference steps (converted with int()).
    guidance: guidance scale (converted with float()).
    seed:     seed for the generator; None or a negative value means
              "pick one from the current time".
    simple:   truthy for the simple explanation, falsy for the technical one.

    Returns a 7-tuple:
      - final image
      - explanation text
      - step slider config
      - image at current step
      - PCA plot
      - norm plot
      - state dict (for slider updates)
    On failure the image/plot slots are None, the text slot carries the
    message, and the slider is reset to a 0..0 range.
    """
    if not prompt or not prompt.strip():
        return _diffusion_failure("⚠️ Please enter a prompt.", {})

    pipe = get_pipe()
    steps = int(steps)
    guidance = float(guidance)

    if seed is None or seed < 0:
        seed_val = int(time.time())
    else:
        seed_val = int(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(seed_val)

    latents_list = []

    def callback(step: int, timestep: int, latents: torch.Tensor):
        # latents shape: (batch, C, H, W); keep the first batch item.
        # On CPU, .cpu().numpy() shares memory with the tensor, so copy to
        # keep the recorded trajectory independent of the pipeline buffers.
        latents_list.append(latents.detach().cpu().numpy()[0].copy())

    t0 = time.time()
    try:
        result = pipe(
            prompt,
            height=256,
            width=256,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            callback=callback,
            callback_steps=1,
        )
    except Exception as e:
        # Surface pipeline failures in the UI instead of crashing the app.
        return _diffusion_failure(f"❌ Diffusion error: {e}", {"error": str(e)})
    total = time.time() - t0

    if not latents_list:
        return _diffusion_failure(
            "❌ No latents collected. Something went wrong inside the pipeline.",
            {"error": "no_latents"},
        )

    final_image = result.images[0]  # PIL

    # Compute PCA trajectory and norms
    pca_pts = compute_pca(latents_list)
    norms = compute_norm(latents_list)
    current_idx = len(latents_list) - 1  # final step

    # Decode image at current step. The current step is the last one, so if
    # decoding fails the pipeline's own final image is shown instead of
    # leaving the preview empty.
    try:
        step_image = decode_latent(pipe, latents_list[current_idx])
    except Exception:
        step_image = final_image

    explanation = _explanation_text(simple)
    explanation += f"\n⏱ **Runtime:** {total:.2f}s • **Steps:** {len(latents_list)} • Seed: {seed_val}"

    # Build plots
    pca_fig = plot_pca(pca_pts, current_idx) if len(pca_pts) > 0 else None
    norm_fig = plot_norm(norms, current_idx) if norms else None

    # State for slider updates
    state = {
        "latents": latents_list,
        "pca": pca_pts,
        "norms": norms,
    }
    step_slider_update = gr.update(maximum=len(latents_list) - 1, value=current_idx)

    return (
        final_image,
        explanation,
        step_slider_update,
        step_image,
        pca_fig,
        norm_fig,
        state,
    )


def _diffusion_failure(message, state):
    """Build the 7-tuple returned by run_diffusion when no result is available."""
    return (
        None,
        message,
        gr.update(maximum=0, value=0),
        None,
        None,
        None,
        state,
    )


def _explanation_text(simple):
    """Return the simple or the technical explanation (without the runtime footer)."""
    if simple:
        return (
            "🧒 **Simple explanation of what you see:**\n\n"
            "1. The model starts from pure noise.\n"
            "2. At each step, it removes some noise and makes the picture clearer.\n"
            "3. Your text prompt tells it what kind of picture to create.\n"
            "4. You can move the slider to see the image at different steps.\n"
        )
    return (
        "🔬 **Technical explanation:**\n\n"
        "- We run a DDIM diffusion process over the latent space.\n"
        "- At each timestep `t`, the UNet predicts noise εₜ and the scheduler updates `zₜ → zₜ₋₁`.\n"
        "- We record `zₜ` at every step and decode it with the VAE.\n"
        "- PCA over flattened latents gives a 2D trajectory of the diffusion path.\n"
        "- The L2 norm plot shows how the latent magnitude evolves per step.\n"
    )
| # ------------------- PLOT HELPERS ------------------- | |
def plot_pca(points, idx):
    """
    Draw the PCA trajectory over all steps and mark the current step.

    points: (N,2) array of PCA coordinates, one row per step.
    idx:    index of the step to highlight; out-of-range values simply
            produce a plot without the highlight marker.

    Returns a plotly Figure, or None when there are no points.
    """
    n_points = points.shape[0]
    if n_points == 0:
        return None

    hover_labels = [f"step {i}" for i in range(n_points)]
    trajectory = go.Scatter(
        x=points[:, 0],
        y=points[:, 1],
        mode="lines+markers",
        name="steps",
        text=hover_labels,
    )

    fig = go.Figure()
    fig.add_trace(trajectory)

    if 0 <= idx < n_points:
        # Red marker on top of the trajectory for the selected step.
        current = go.Scatter(
            x=[points[idx, 0]],
            y=[points[idx, 1]],
            mode="markers+text",
            text=[f"step {idx}"],
            textposition="top center",
            marker=dict(size=12, color="red"),
            name="current",
        )
        fig.add_trace(current)

    fig.update_layout(
        title="Latent PCA trajectory",
        xaxis_title="PC1",
        yaxis_title="PC2",
        height=350,
    )
    return fig
def plot_norm(norms, idx):
    """
    Plot the latent L2 norm against the step index and mark the current step.

    norms: list of per-step norms.
    idx:   index of the step to highlight; ignored when out of range.

    Returns a plotly Figure, or None when norms is empty.
    """
    if not norms:
        return None

    n_steps = len(norms)
    curve = go.Scatter(
        x=list(range(n_steps)),
        y=norms,
        mode="lines+markers",
        name="‖latent‖₂",
    )

    fig = go.Figure()
    fig.add_trace(curve)

    if 0 <= idx < n_steps:
        # Red marker for the selected step.
        marker_trace = go.Scatter(
            x=[idx],
            y=[norms[idx]],
            mode="markers",
            marker=dict(size=12, color="red"),
            name="current",
        )
        fig.add_trace(marker_trace)

    fig.update_layout(
        title="Latent L2 norm vs step",
        xaxis_title="Step index",
        yaxis_title="‖latent‖₂",
        height=350,
    )
    return fig
| # ------------------- SLIDER UPDATE ------------------- | |
def update_step(state, idx):
    """
    Refresh the step view when the user moves the slider.

    state: dict produced by run_diffusion ("latents", "pca", "norms"), or
           an empty/absent state when no successful run exists yet.
    idx:   slider position; clamped into the valid range of recorded steps.

    Returns three gradio updates: the decoded image for that step, the PCA
    plot and the norm plot with the step highlighted. All three are cleared
    when there is nothing to show.
    """
    def _cleared():
        return gr.update(value=None), gr.update(value=None), gr.update(value=None)

    if not state or "latents" not in state:
        return _cleared()

    latents, pca_pts, norms = state["latents"], state["pca"], state["norms"]
    if not latents:
        return _cleared()

    # Clamp the requested step into [0, last recorded step].
    last = len(latents) - 1
    idx = min(max(int(idx), 0), last)

    pipe = get_pipe()
    try:
        img = decode_latent(pipe, latents[idx])
    except Exception:
        # Best effort: keep the plots updating even if decoding fails.
        img = None

    pca_fig = None if pca_pts is None else plot_pca(pca_pts, idx)
    norm_fig = plot_norm(norms, idx) if norms else None

    return gr.update(value=img), gr.update(value=pca_fig), gr.update(value=norm_fig)
| # ------------------- GRADIO UI ------------------- | |
# Build the Gradio interface: run controls and final image/explanation in a
# row, then the step explorer (slider, per-step image, PCA and norm plots).
# NOTE(review): the nesting of the components below was reconstructed from a
# capture that had lost its indentation — confirm against the deployed layout.
with gr.Blocks(title="Stable Diffusion v1-4 — CPU Diffusion Visualizer",theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
    gr.Markdown(
        "This app shows **how a real Stable Diffusion model** turns noise into an image, step by step.\n"
        "- Uses `CompVis/stable-diffusion-v1-4` on CPU\n"
        "- 256×256 resolution for speed\n"
        "- You can scrub through all diffusion steps\n"
    )
    with gr.Row():
        # Left column: inputs passed to run_diffusion.
        with gr.Column():
            prompt = gr.Textbox(
                label="Prompt",
                value="a small cozy cabin in the forest, digital art",
                lines=3
            )
            steps = gr.Slider(10, 30, value=20, step=1, label="Number of diffusion steps")
            guidance = gr.Slider(1.0, 12.0, value=7.5, step=0.5, label="Guidance scale")
            # run_diffusion treats a negative seed as "derive one from the clock".
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            simple = gr.Checkbox(label="Simple explanation", value=True)
            run = gr.Button("Run diffusion", variant="primary")
        # Right column: final result and the explanation text.
        with gr.Column():
            final = gr.Image(label="Final generated image")
            expl = gr.Markdown(label="Explanation")
    gr.Markdown("### 🔍 Explore the denoising process step-by-step")
    # Starts as a 0..0 range; run_diffusion sets the maximum after a run.
    step_slider = gr.Slider(0, 0, value=0, step=1, label="View step (0 = early noise, max = final)")
    step_img = gr.Image(label="Image at this diffusion step")
    pca_plot = gr.Plot(label="Latent PCA trajectory")
    norm_plot = gr.Plot(label="Latent norm vs step")
    # Holds the dict returned by run_diffusion (latents, PCA points, norms).
    state = gr.State()
    # The seven outputs match the 7-tuple returned by run_diffusion.
    run.click(
        run_diffusion,
        inputs=[prompt, steps, guidance, seed, simple],
        outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state]
    )
    # Moving the slider re-decodes the chosen step and re-highlights both plots.
    step_slider.change(
        update_step,
        inputs=[state, step_slider],
        outputs=[step_img, pca_plot, norm_plot]
    )
# Start the server on all interfaces at port 7860.
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, pwa=True)