# (HuggingFace Spaces status banner removed — page-scrape residue, not code)
# ==========================================================
# Stable Diffusion v1-4 — CPU Optimized Diffusion Visualizer
# REAL images (256×256) on free HuggingFace CPU
# With: step-by-step latents, PCA path, norm plot, latents decode
# ==========================================================
import time
import warnings

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from PIL import Image
from sklearn.decomposition import PCA

warnings.filterwarnings("ignore")
# ------------------- CPU SETTINGS -------------------
DEVICE = "cpu"  # free HF Spaces tier has no GPU
# Disable MKLDNN for safety (prevents matmul errors on SD)
torch.backends.mkldnn.enabled = False
MODEL_ID = "CompVis/stable-diffusion-v1-4"
# Module-level cache so the ~4 GB pipeline is loaded only once per process.
PIPE_CACHE = None
# ------------------- LOAD SD MODEL -------------------
def get_pipe():
    """Load Stable Diffusion v1-4 on CPU (fp32) and cache it.

    Returns
    -------
    StableDiffusionPipeline
        The module-level cached pipeline; subsequent calls reuse it so the
        large model is only downloaded/loaded once per process.
    """
    global PIPE_CACHE
    if PIPE_CACHE is not None:
        return PIPE_CACHE
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    )
    # Replace scheduler with DDIM (better for stepping: per-step latents
    # are well defined for the visualizer callback).
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to(DEVICE)
    # Bypass the safety checker to avoid weird errors on CPU.
    # FIX: the checker must return one NSFW flag PER image (a list),
    # not a scalar False — the pipeline post-processes flags per image.
    pipe.safety_checker = lambda images, clip_input: (images, [False] * len(images))
    # NOTE(review): enable_attention_slicing(None) actually DISABLES slicing
    # in diffusers (matching the "disable features not needed" intent here);
    # enable_attention_slicing("auto") would lower peak RAM — confirm intent.
    pipe.enable_attention_slicing(None)
    PIPE_CACHE = pipe
    return PIPE_CACHE
# ------------------- PCA + NORM -------------------
def compute_pca(latents):
    """Project each (flattened) latent onto its first two principal components.

    Parameters
    ----------
    latents : sequence of np.ndarray
        One latent array per diffusion step.

    Returns
    -------
    np.ndarray of shape (len(latents), 2)
        PCA coordinates; all-zeros when PCA is not possible (fewer than
        two samples) or the fit fails.
    """
    X = np.stack([np.asarray(x).ravel() for x in latents])
    if X.shape[0] < 2:
        return np.zeros((X.shape[0], 2))
    try:
        return PCA(n_components=2).fit_transform(X)
    except Exception:
        # Best-effort: the PCA plot is decorative — never fail the run.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        return np.zeros((X.shape[0], 2))
def compute_norm(latents):
    """Return the Euclidean (L2) norm of each latent as a plain float."""
    norms = []
    for latent in latents:
        norms.append(float(np.linalg.norm(latent.flatten())))
    return norms
# ------------------- LATENT DECODER -------------------
def decode_latent(pipe, latent_np):
    """Decode one latent (numpy, CHW) into a PIL image via the pipeline's VAE."""
    scale = pipe.vae.config.scaling_factor
    batch = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        decoded = pipe.vae.decode(batch / scale).sample
    # Map VAE output from [-1, 1] to [0, 255] uint8 in HWC order.
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    arr = (pixels[0].permute(1, 2, 0).cpu().numpy() * 255).astype("uint8")
    return Image.fromarray(arr)
# ------------------- RUN DIFFUSION -------------------
def run_diffusion(prompt, steps, guidance, seed, simple):
    """Run SD v1-4 on CPU, capturing the latent at every DDIM step.

    Parameters mirror the Gradio inputs (prompt text, step count, guidance
    scale, seed, simple-explanation flag).

    Returns a 7-tuple matching the Gradio outputs:
    (final image, explanation markdown, step-slider update,
     decoded image at the last step, PCA figure, norm figure, state dict).
    """
    if not prompt.strip():
        return None, "Enter prompt", gr.update(), None, None, None, {}
    pipe = get_pipe()
    # FIX: gr.Number delivers a float, but torch.Generator.manual_seed
    # requires an int — without the cast every fixed seed raised TypeError.
    seed = int(seed)
    generator = torch.Generator("cpu").manual_seed(
        seed if seed >= 0 else int(time.time())
    )
    latents_list = []
    timesteps = []

    def cb(step, t, latents):
        # Record a CPU numpy copy of the (single-image) latent every step.
        latents_list.append(latents.detach().cpu().numpy()[0])
        timesteps.append(int(t))

    t0 = time.time()
    result = pipe(
        prompt,
        height=256,
        width=256,
        num_inference_steps=int(steps),  # sliders may also deliver floats
        guidance_scale=guidance,
        generator=generator,
        callback=cb,
        callback_steps=1,
    )
    total = time.time() - t0
    final = result.images[0]
    pca = compute_pca(latents_list)
    norms = compute_norm(latents_list)
    cur = len(latents_list) - 1
    step_image = decode_latent(pipe, latents_list[cur])
    if simple:
        explanation = (
            "🧒 **Simple Explanation**\n"
            "The model starts with noise, slowly removes it, and reveals an image.\n"
        )
    else:
        explanation = (
            "🔬 **Technical Explanation**\n"
            "We collect latents at each DDIM step, decode them via VAE, and visualize their PCA path."
        )
    explanation += f"\n⏱ Runtime: {total:.2f}s"
    state = {
        "latents": latents_list,
        "pca": pca,
        "norms": norms
    }
    return (
        final,
        explanation,
        # Resize the step slider to the number of captured steps.
        gr.update(maximum=cur, value=cur),
        step_image,
        plot_pca(pca, cur),
        plot_norm(norms, cur),
        state
    )
# ------------------- PLOT FUNCTIONS -------------------
def plot_pca(points, idx):
    """Plot the 2-D PCA trajectory of the latents, marking step *idx* in red."""
    trajectory = go.Scatter(x=points[:, 0], y=points[:, 1], mode="lines+markers")
    highlight = go.Scatter(
        x=[points[idx, 0]],
        y=[points[idx, 1]],
        mode="markers",
        marker=dict(size=12, color="red"),
    )
    fig = go.Figure(data=[trajectory, highlight])
    fig.update_layout(height=350, title="PCA Trajectory")
    return fig
def plot_norm(norms, idx):
    """Plot the latent L2 norm per step, marking step *idx* in red."""
    curve = go.Scatter(y=norms, mode="lines+markers")
    highlight = go.Scatter(
        x=[idx],
        y=[norms[idx]],
        mode="markers",
        marker=dict(size=12, color="red"),
    )
    fig = go.Figure(data=[curve, highlight])
    fig.update_layout(height=350, title="Latent Norm Over Steps")
    return fig
# ------------------- SLIDER UPDATE -------------------
def update_step(state, idx):
    """Slider handler: re-render the decoded latent and plots at step *idx*.

    FIX: robust to firing before any run (state is None/empty, previously a
    TypeError), to Gradio delivering the slider value as a float, and to a
    stale slider range indexing past the captured latents.
    """
    if not state or not state.get("latents"):
        # Nothing generated yet — leave all three outputs unchanged.
        return gr.update(), gr.update(), gr.update()
    latents = state["latents"]
    idx = max(0, min(int(idx), len(latents) - 1))
    pipe = get_pipe()
    img = decode_latent(pipe, latents[idx])
    return (
        img,
        plot_pca(state["pca"], idx),
        plot_norm(state["norms"], idx)
    )
# ------------------- UI -------------------
# NOTE(review): indentation was lost in the source; the step browser is
# reconstructed at Blocks level below the input/output Row — confirm the
# intended layout placement.
with gr.Blocks(title="SD v1-4 CPU Diffusion Visualizer") as demo:
    gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
    gr.Markdown("This version produces **real images**, optimized for free HF CPU.")
    with gr.Row():
        with gr.Column():
            # Generation controls.
            prompt = gr.Textbox(label="Prompt", value="a cute cat in watercolor")
            steps = gr.Slider(10, 30, value=20, step=1, label="Steps")
            guidance = gr.Slider(3, 12, value=7.5, step=0.5, label="Guidance")
            seed = gr.Number(label="Seed (-1 for random)", value=-1)
            simple = gr.Checkbox(label="Simple Explanation", value=True)
            run = gr.Button("Run Diffusion", variant="primary")
        with gr.Column():
            final = gr.Image(label="Final Image")
            expl = gr.Markdown()
    # Step browser: run_diffusion enlarges this slider's maximum after a run.
    step_slider = gr.Slider(0, 0, value=0, step=1, label="View Step")
    step_img = gr.Image(label="Latent Image at Step")
    pca_plot = gr.Plot(label="PCA")
    norm_plot = gr.Plot(label="Norm Plot")
    # Cross-callback store: latents, PCA points and norms from the last run.
    state = gr.State()
    run.click(
        run_diffusion,
        inputs=[prompt, steps, guidance, seed, simple],
        outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state]
    )
    step_slider.change(update_step, [state, step_slider], [step_img, pca_plot, norm_plot])
demo.launch()