Diffusers
Safetensors
How to use from the Diffusers library
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# Auto-generated "Use from the Diffusers library" snippet.
# NOTE(review): this looks like a generic text-to-image template. The model
# card describes a class-conditioned MNIST DDPM (conditioned on digit labels,
# not text), and the self-contained example further down loads the weights
# with diffusers.UNet2DModel rather than a pipeline. Whether this repo ships a
# pipeline config that DiffusionPipeline.from_pretrained can load is not
# visible from here -- confirm before relying on this snippet.
# Change device_map to "mps" on Apple-silicon machines.
# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("harveymannering/mnist-ddpm", dtype=torch.bfloat16, device_map="cuda")

# NOTE(review): a free-text prompt presumably has no meaning for a model that
# is conditioned on digit labels 0-9; see the example code below for
# label-conditioned sampling.
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]

Conditional DDPM on MNIST

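This model is a straightforward implementation of the method described in the Denoising Diffusion Probabilistic Models (DDPM) paper, trained on the MNIST dataset. We use a linear noise schedule and train for 25 epochs on the entire training set. The model is conditioned on the MNIST digit labels. For 15% of the training steps it is instead trained for unconditional generation, using the extra label index 10 as a "null" class. Below we show results from the training run, including the MSE loss plot and generation results with and without classifier-free guidance. Note that the example code further down does not use the original 1000-step DDPM sampling procedure. Instead, it samples with the deterministic DDIM update over a reduced number of steps (20 in the example).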

Generation Results (Without Classifier-Free Guidance)

image

Generation Results (With Classifier-Free Guidance)

image

Training Loss

image

Example Code

Self-contained example code is shown below; it requires `torch`, `diffusers` and `matplotlib`. If this code stops working in the future, please post the errors you encounter in the "Community" tab on this page:

import torch
import diffusers
import matplotlib.pyplot as plt

# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Download the pretrained diffusion model (a diffusers UNet2DModel) and move
# it to the chosen device.
model = diffusers.UNet2DModel.from_pretrained("harveymannering/mnist-ddpm")
model = model.to(device)

# Define noise schedule: the linear beta schedule from the DDPM paper
# (beta rising from 1e-4 to 0.02 over 1000 steps). Presumably this has to
# match the schedule the checkpoint was trained with.
beta_start = 0.0001
beta_end = 0.02
timesteps = 1000
betas = torch.linspace(beta_start, beta_end, timesteps).to(device)
alphas = 1.0 - betas
# alphas_cumprod[t] is the running product alphas[0] * ... * alphas[t]
# ("alpha-bar_t" in the DDPM/DDIM papers). `dim` is the documented keyword.
alphas_cumprod = torch.cumprod(alphas, dim=0)

# Define sampling code
@torch.no_grad()
def sample(x, net, labels=None, total_steps=50, w=1.0):
    """Sample images from ``net`` with deterministic DDIM updates (eta = 0).

    Starting from the noise ``x`` at timestep ``timesteps - 1``, the sampler
    visits ``total_steps`` evenly spaced training timesteps in decreasing
    order and applies Eq. 12 of the DDIM paper
    (https://arxiv.org/pdf/2010.02502) at each one. The last update targets
    alpha_cumprod = 1, so the result is the network's clean-image estimate.

    Args:
        x: Starting noise, batch first. The example below passes an
            (N, 1, 28, 28) float tensor.
        net: Noise-prediction network, called as
            ``net(sample, timestep, class_labels)``. The prediction is read
            from the ``.sample`` attribute of its output, as returned by
            diffusers' ``UNet2DModel``.
        labels: Integer tensor with one class label per batch item (digits
            0-9). If ``None``, labels are drawn uniformly from 0-9.
        total_steps: Number of denoising steps, i.e. timesteps visited.
        w: Classifier-free guidance weight. ``w == 1`` uses only the
            conditional prediction, ``w == 0`` uses only the unconditional
            prediction, and ``w > 1`` pushes samples further toward the label.

    Returns:
        The denoised batch, with the same shape as ``x``.
    """
    batch_size = x.shape[0]
    # Work on x's device rather than assuming it matches the module-level
    # `device` the schedule was built on (.to is a no-op when they match).
    acp = alphas_cumprod.to(x.device)

    # Generate random labels if none are provided (drawn on the CPU, then moved).
    if labels is None:
        labels = torch.randint(10, (batch_size,))
    labels = labels.to(x.device)

    # Any weight other than 1 needs the unconditional branch: w < 1
    # interpolates toward it, and w = 0 is fully unconditional.
    use_cfg = w != 1.0
    if use_cfg:
        # Label 10 is the "null" class used for unconditional training. The
        # second half of each doubled batch therefore gets the unconditional
        # prediction. Labels never change, so build this once.
        labels_input = torch.cat([labels, torch.full_like(labels, 10)])
    else:
        labels_input = labels

    # Evenly spaced, non-Markovian (DDIM-style) subset of the training timesteps.
    schedule = torch.linspace(0, timesteps - 1, total_steps, dtype=torch.long).tolist()

    # Run inference starting from random noise, highest timestep first.
    for idx in reversed(range(total_steps)):
        t_cur = schedule[idx]
        t = torch.full((batch_size,), t_cur, dtype=torch.long, device=x.device)

        # Duplicate the batch so one forward pass yields both predictions.
        if use_cfg:
            x_input, t_input = torch.cat([x, x]), torch.cat([t, t])
        else:
            x_input, t_input = x, t

        # Run neural network
        eps = net(x_input, t_input, labels_input).sample

        # Classifier-free guidance; equivalent to
        # eps_uncond + w * (eps_cond - eps_uncond).
        if use_cfg:
            eps_cond, eps_uncond = eps[:batch_size], eps[batch_size:]
            eps = w * eps_cond + (1 - w) * eps_uncond

        # Equation 12 - Denoising Diffusion Implicit Models, with sigma = 0.
        # The step after the smallest scheduled timestep targets alpha_cumprod = 1.
        a_t = acp[t_cur]
        a_prev = acp[schedule[idx - 1]] if idx > 0 else acp.new_tensor(1.0)
        pred_x0 = (x - torch.sqrt(1 - a_t) * eps) / torch.sqrt(a_t)
        dir_to_xt = torch.sqrt(1 - a_prev) * eps
        x = torch.sqrt(a_prev) * pred_x0 + dir_to_xt

    return x

# Generate images
num_samples = 5  # batch size; also the number of plot columns
noise = torch.randn(num_samples, 1, 28, 28).to(device).float()
labels = torch.randint(10, (num_samples,)).to(device)
samples = sample(noise, model, labels, total_steps=20, w=3.0)

# Plot samples, each titled with the digit it was conditioned on.
# squeeze=False keeps `axes` a 2-D array even when num_samples == 1.
fig, axes = plt.subplots(1, num_samples, figsize=(2 * num_samples, 2), squeeze=False)
for ax, img, digit in zip(axes[0], samples[:, 0].cpu(), labels.tolist()):
    ax.imshow(img.numpy(), cmap="gray")
    ax.set_title(str(digit))
    ax.axis("off")
plt.show()

The full training and inference code can be found at https://github.com/harveymannering/boilerplate_code/blob/main/ddpm.ipynb.

Downloads last month
21
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Papers for harveymannering/mnist-ddpm