PraneshJs's picture
Update app.py
4a4fecc verified
raw
history blame
6.48 kB
# ==========================================================
# Stable Diffusion v1-4 — CPU Optimized Diffusion Visualizer
# Generates real images (256×256) on the free Hugging Face CPU tier
# Features: step-by-step latents, PCA trajectory, norm plot, latent decoding
# ==========================================================
import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionPipeline, DDIMScheduler
from sklearn.decomposition import PCA
import plotly.graph_objects as go
from PIL import Image
import time
import warnings
warnings.filterwarnings("ignore")

# ------------------- CPU SETTINGS -------------------
# Model and device configuration shared by every function below.
MODEL_ID = "CompVis/stable-diffusion-v1-4"
DEVICE = "cpu"
# Lazily populated by get_pipe(); None until the first generation request.
PIPE_CACHE = None

# Disable MKLDNN for safety (prevents matmul errors on SD)
torch.backends.mkldnn.enabled = False
# ------------------- LOAD SD MODEL -------------------
def get_pipe():
    """Return the shared Stable Diffusion pipeline, loading it on first use.

    The pipeline is loaded once in float32 on ``DEVICE``, switched to a DDIM
    scheduler, loaded without a safety checker, and memoised in the
    module-level ``PIPE_CACHE`` so later calls return the same object.

    Returns:
        The cached ``StableDiffusionPipeline`` instance.
    """
    global PIPE_CACHE
    # The cache's only "empty" value is None; don't rely on the pipeline
    # object's truthiness.
    if PIPE_CACHE is not None:
        return PIPE_CACHE
    pipe = StableDiffusionPipeline.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
        # Disable the safety checker the supported way. Replacing it with a
        # lambda that returns a bare ``False`` breaks diffusers releases that
        # iterate over the per-image NSFW flags. ``None`` also skips loading
        # the checker's weights, which saves memory on CPU.
        safety_checker=None,
        requires_safety_checker=False,
    )
    # Replace scheduler with DDIM (better for stepping)
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
    pipe.to(DEVICE)
    # Explicitly turn attention slicing off. The original
    # enable_attention_slicing(None) call was the same thing, spelled obscurely.
    pipe.disable_attention_slicing()
    PIPE_CACHE = pipe
    return PIPE_CACHE
# ------------------- PCA + NORM -------------------
def compute_pca(latents):
    """Project a sequence of latents onto their first two principal components.

    Each latent is flattened into one row, and the stacked rows are fitted
    with a 2-component PCA, giving one 2-D point per captured step.

    Args:
        latents: Sequence of numpy arrays. All must have the same number of
            elements so they can be stacked.

    Returns:
        An ``(n, 2)`` array of projected points, where ``n == len(latents)``.
        It is all zeros when there are fewer than two latents or when the PCA
        fit rejects the data.
    """
    if len(latents) == 0:
        # np.stack raises on an empty sequence; an empty trajectory is valid.
        return np.zeros((0, 2))
    rows = np.stack([np.asarray(x).flatten() for x in latents])
    if rows.shape[0] < 2:
        # A 2-component PCA needs at least two samples.
        return np.zeros((rows.shape[0], 2))
    try:
        return PCA(n_components=2).fit_transform(rows)
    except (ValueError, np.linalg.LinAlgError):
        # Best-effort visualisation. Degenerate input (NaN/inf values, or too
        # few features) falls back to a flat trajectory instead of failing
        # the request. Only data errors are caught, so interrupts still work.
        return np.zeros((rows.shape[0], 2))
def compute_norm(latents):
    """Return the Euclidean (L2) norm of each flattened latent as a Python float."""
    norms = []
    for latent in latents:
        flat = latent.flatten()
        norms.append(float(np.linalg.norm(flat)))
    return norms
# ------------------- LATENT DECODER -------------------
def decode_latent(pipe, latent_np):
    """Decode one latent array into a PIL image with the pipeline's VAE.

    Args:
        pipe: Pipeline whose ``vae`` does the decoding.
        latent_np: A single latent as a numpy array without a batch axis.
            A batch axis of size 1 is prepended here.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded sample.
    """
    vae = pipe.vae
    batch = torch.from_numpy(latent_np).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        # Undo the VAE scaling factor before decoding.
        decoded = vae.decode(batch / vae.config.scaling_factor).sample
    # Map the decoder output (assumed to be roughly in [-1, 1]) to [0, 1],
    # clamping any stray values.
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    hwc = pixels[0].permute(1, 2, 0).cpu().numpy()
    return Image.fromarray((hwc * 255).astype("uint8"))
# ------------------- RUN DIFFUSION -------------------
def run_diffusion(prompt, steps, guidance, seed, simple):
    """Run Stable Diffusion and capture the latent after every denoising step.

    Args:
        prompt: Text prompt. A blank or None prompt aborts early with a hint.
        steps: Number of inference steps. Cast to int in case the UI delivers
            a float.
        guidance: Guidance scale passed to the pipeline.
        seed: RNG seed. A negative or missing seed uses the current time.
            Cast to int because ``torch.Generator.manual_seed`` rejects floats.
        simple: True for the simple explanation, False for the technical one.

    Returns:
        A 7-tuple for the Gradio outputs, in this order:
        - final image
        - explanation markdown
        - slider update
        - decoded image of the last captured step
        - PCA figure
        - norm figure
        - the state dict that ``update_step`` reads
    """
    if not prompt or not prompt.strip():
        return None, "Enter prompt", gr.update(), None, None, None, {}
    pipe = get_pipe()

    # gr.Number yields floats (or None when cleared), but manual_seed only
    # accepts integers.
    if seed is None or seed < 0:
        seed = int(time.time())
    generator = torch.Generator("cpu").manual_seed(int(seed))

    latents_list = []
    timesteps = []

    def cb(step, t, latents):
        # Per-step pipeline callback: keep a CPU numpy copy of the first sample.
        latents_list.append(latents.detach().cpu().numpy()[0])
        timesteps.append(int(t))

    # perf_counter is monotonic, so it is the right clock for measuring runtime.
    t0 = time.perf_counter()
    result = pipe(
        prompt,
        height=256,
        width=256,
        num_inference_steps=int(steps),
        guidance_scale=guidance,
        generator=generator,
        callback=cb,
        callback_steps=1,
    )
    total = time.perf_counter() - t0
    final = result.images[0]

    if simple:
        explanation = (
            "🧒 **Simple Explanation**\n"
            "The model starts with noise, slowly removes it, and reveals an image.\n"
        )
    else:
        explanation = (
            "🔬 **Technical Explanation**\n"
            "We collect latents at each DDIM step, decode them via VAE, and visualize their PCA path."
        )
    explanation += f"\n⏱ Runtime: {total:.2f}s"

    if not latents_list:
        # The callback never fired, so there is no trajectory. Show the final
        # image and leave the step viewer untouched instead of indexing [-1].
        return final, explanation, gr.update(), None, None, None, {}

    pca = compute_pca(latents_list)
    norms = compute_norm(latents_list)
    cur = len(latents_list) - 1
    step_image = decode_latent(pipe, latents_list[cur])

    state = {
        "latents": latents_list,
        "timesteps": timesteps,
        "pca": pca,
        "norms": norms,
    }
    return (
        final,
        explanation,
        gr.update(maximum=cur, value=cur),
        step_image,
        plot_pca(pca, cur),
        plot_norm(norms, cur),
        state,
    )
# ------------------- PLOT FUNCTIONS -------------------
def plot_pca(points, idx):
    """Plot the 2-D PCA trajectory and mark step ``idx`` with a red dot.

    Args:
        points: Array of shape ``(n, 2)``, one projected point per step.
        idx: Index of the step to highlight.

    Returns:
        A plotly ``go.Figure``.
    """
    path = go.Scatter(x=points[:, 0], y=points[:, 1], mode="lines+markers")
    current = go.Scatter(
        x=[points[idx, 0]],
        y=[points[idx, 1]],
        mode="markers",
        marker={"size": 12, "color": "red"},
    )
    fig = go.Figure()
    for trace in (path, current):
        fig.add_trace(trace)
    fig.update_layout(height=350, title="PCA Trajectory")
    return fig
def plot_norm(norms, idx):
    """Plot latent norms against step index and mark step ``idx`` in red.

    Args:
        norms: Sequence of per-step norms. Its positions are the x values.
        idx: Index of the step to highlight.

    Returns:
        A plotly ``go.Figure``.
    """
    curve = go.Scatter(y=norms, mode="lines+markers")
    current = go.Scatter(
        x=[idx],
        y=[norms[idx]],
        mode="markers",
        marker={"size": 12, "color": "red"},
    )
    fig = go.Figure()
    for trace in (curve, current):
        fig.add_trace(trace)
    fig.update_layout(height=350, title="Latent Norm Over Steps")
    return fig
# ------------------- SLIDER UPDATE -------------------
def update_step(state, idx):
    """Refresh the step viewer for the step selected on the slider.

    Args:
        state: Dict produced by ``run_diffusion`` with "latents", "pca" and
            "norms". It may be None before any run, or {} after an
            empty-prompt run.
        idx: Slider value (step index). It may arrive as a float.

    Returns:
        A 3-tuple: (decoded step image, PCA figure, norm figure). When no
        trajectory has been captured, each output is left unchanged via
        ``gr.update()``.
    """
    latents = (state or {}).get("latents")
    if not latents:
        # Nothing captured yet: avoid a KeyError/TypeError and keep the
        # current outputs.
        return gr.update(), gr.update(), gr.update()
    # Lists need an int index. Clamp defensively so a slider value outside
    # the captured range can't raise an IndexError.
    idx = min(max(int(idx or 0), 0), len(latents) - 1)
    pipe = get_pipe()
    img = decode_latent(pipe, latents[idx])
    return (
        img,
        plot_pca(state["pca"], idx),
        plot_norm(state["norms"], idx),
    )
# ------------------- UI -------------------
with gr.Blocks(title="SD v1-4 CPU Diffusion Visualizer") as demo:
    gr.Markdown("# 🧠 Stable Diffusion v1-4 — CPU Visualizer (256×256)")
    gr.Markdown("This version produces **real images**, optimized for free HF CPU.")
    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Prompt", value="a cute cat in watercolor")
            steps = gr.Slider(10, 30, value=20, step=1, label="Steps")
            guidance = gr.Slider(3, 12, value=7.5, step=0.5, label="Guidance")
            # precision=0 makes Gradio pass an int; torch.Generator.manual_seed
            # rejects floats such as 42.0.
            seed = gr.Number(label="Seed (-1 for random)", value=-1, precision=0)
            simple = gr.Checkbox(label="Simple Explanation", value=True)
            run = gr.Button("Run Diffusion", variant="primary")
        with gr.Column():
            final = gr.Image(label="Final Image")
            expl = gr.Markdown()
    # Step viewer. run_diffusion widens the slider's range after each run.
    step_slider = gr.Slider(0, 0, value=0, step=1, label="View Step")
    step_img = gr.Image(label="Latent Image at Step")
    pca_plot = gr.Plot(label="PCA")
    norm_plot = gr.Plot(label="Norm Plot")
    state = gr.State()
    run.click(
        run_diffusion,
        inputs=[prompt, steps, guidance, seed, simple],
        outputs=[final, expl, step_slider, step_img, pca_plot, norm_plot, state],
    )
    step_slider.change(update_step, [state, step_slider], [step_img, pca_plot, norm_plot])

if __name__ == "__main__":
    # CPU generation can take minutes. Queuing keeps long requests from timing
    # out on Gradio versions where the queue is not enabled by default.
    demo.queue().launch()