Diffusion-CIFAR10: Unconditional DDPM on CIFAR-10

A Denoising Diffusion Probabilistic Model (DDPM) trained on CIFAR-10 for unconditional image generation. The model generates 32x32 RGB images spanning all 10 CIFAR-10 categories without class conditioning.

Model Description

Implements DDPM (Ho et al., 2020) with a U-Net denoising network featuring self-attention blocks and sinusoidal time embeddings.

Architecture: U-Net

Time Embedding:

  • Sinusoidal timestep encoding → two-layer FFN with Swish activation
  • Injected as an additive bias into every residual block
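
The time-embedding path described above can be sketched roughly as follows. This is a minimal illustration, not the code in model.py: the class layout, the widths `d_model=128` and `d_out=512`, and the base frequency constant 10000 are assumptions chosen for the example.

```python
import math

import torch
import torch.nn as nn


class TimeEmbedding(nn.Module):
    """Sinusoidal timestep encoding followed by a two-layer FFN with Swish.

    d_model and d_out are illustrative widths, not values read from the repository.
    """

    def __init__(self, d_model: int = 128, d_out: int = 512):
        super().__init__()
        assert d_model % 2 == 0, "d_model must be even"
        self.d_model = d_model
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_out),
            nn.SiLU(),  # SiLU is Swish with beta = 1
            nn.Linear(d_out, d_out),
        )

    def sinusoidal(self, t: torch.Tensor) -> torch.Tensor:
        half = self.d_model // 2
        # Frequencies decay geometrically from 1 toward 1/10000.
        exponent = torch.arange(half, dtype=torch.float32, device=t.device) / half
        freqs = torch.exp(-math.log(10000.0) * exponent)
        args = t.float()[:, None] * freqs[None, :]
        # First half of the vector holds sines, second half cosines.
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        return self.ffn(self.sinusoidal(t))
```

Calling `TimeEmbedding()(torch.tensor([0, 10, 999]))` yields one embedding vector per timestep, which each residual block can then project to its own channel count.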

Residual Blocks (ResBlock):

  • Two 3x3 Conv2d with GroupNorm and Swish activation
  • Time embedding added between convolutions
  • 1x1 projection on skip path when input/output channels differ
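
A residual block along the lines listed above might look like the sketch below. The GroupNorm group count (32), the dropout placement, and the extra Swish + Linear projection of the time embedding are assumptions for illustration; the actual ResBlock in model.py may differ.

```python
import torch
import torch.nn as nn


class ResBlock(nn.Module):
    """Two 3x3 convs with GroupNorm + Swish; time embedding added in between."""

    def __init__(self, in_ch, out_ch, tdim, dropout=0.1, groups=32):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(groups, in_ch),
            nn.SiLU(),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        )
        # Project the time embedding to one bias value per output channel.
        self.temb_proj = nn.Sequential(nn.SiLU(), nn.Linear(tdim, out_ch))
        self.block2 = nn.Sequential(
            nn.GroupNorm(groups, out_ch),
            nn.SiLU(),
            nn.Dropout(dropout),  # placement is an assumption
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        )
        # 1x1 projection on the skip path only when channel counts differ.
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, kernel_size=1)
        else:
            self.shortcut = nn.Identity()

    def forward(self, x, temb):
        h = self.block1(x)
        # Broadcast the per-channel bias over the spatial dimensions.
        h = h + self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)
        return h + self.shortcut(x)
```

The spatial size is preserved (padding 1 with 3x3 kernels), so only the channel count changes between input and output.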

Attention (AttnBlock):

  • Single-head self-attention (query, key, value projections)
  • Applied at the bottleneck and selected resolution levels
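
Single-head self-attention over the spatial positions of a feature map can be sketched as below. The use of 1x1 convolutions for the projections, the GroupNorm pre-normalisation, and the residual connection are assumptions for this example rather than details confirmed by the card.

```python
import math

import torch
import torch.nn as nn


class AttnBlock(nn.Module):
    """Single-head self-attention across the H*W positions of a feature map."""

    def __init__(self, ch, groups=32):
        super().__init__()
        self.norm = nn.GroupNorm(groups, ch)
        self.q = nn.Conv2d(ch, ch, kernel_size=1)
        self.k = nn.Conv2d(ch, ch, kernel_size=1)
        self.v = nn.Conv2d(ch, ch, kernel_size=1)
        self.proj = nn.Conv2d(ch, ch, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.norm(x)
        q = self.q(h).flatten(2).transpose(1, 2)  # (B, HW, C)
        k = self.k(h).flatten(2)                  # (B, C, HW)
        v = self.v(h).flatten(2).transpose(1, 2)  # (B, HW, C)
        # Scaled dot-product attention weights between all pairs of positions.
        w = torch.softmax(torch.bmm(q, k) / math.sqrt(C), dim=-1)  # (B, HW, HW)
        out = torch.bmm(w, v).transpose(1, 2).reshape(B, C, H, W)
        return x + self.proj(out)
```

Because the attention matrix has size (H*W) x (H*W), this kind of block is usually restricted to low-resolution levels, which matches the card's note that attention is applied only at the bottleneck and selected resolutions.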

Downsampling: Strided Conv2d (stride 2)

Upsampling: Nearest-neighbour interpolation + Conv2d
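
As a rough sketch of these two resampling layers, assuming 3x3 kernels with padding 1 (the kernel size is not stated in the card):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class DownSample(nn.Module):
    """Halve the spatial resolution with a stride-2 convolution."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)


class UpSample(nn.Module):
    """Double the spatial resolution by nearest-neighbour interpolation, then convolve."""

    def __init__(self, ch):
        super().__init__()
        self.conv = nn.Conv2d(ch, ch, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(x)
```

With these settings a 32x32 feature map becomes 16x16 after `DownSample` and returns to 32x32 after `UpSample`.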

Diffusion Process

  • Parameterization: Noise prediction (epsilon)
  • Schedule: Linear beta schedule
  • Loss: MSE between predicted and actual noise
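
To make the three items above concrete, here is a small pure-Python sketch of the linear beta schedule, the closed-form forward noising step, and the noise-prediction loss. The endpoint values 1e-4 and 0.02 and T=1000 are taken from the sampler call in the usage snippet of this card; the function names are hypothetical and do not come from diffusion.py.

```python
import math
import random


def linear_betas(beta_1=1e-4, beta_T=0.02, T=1000):
    """Linearly spaced noise variances beta_1 ... beta_T."""
    step = (beta_T - beta_1) / (T - 1)
    return [beta_1 + i * step for i in range(T)]


def alpha_bars(betas):
    """Cumulative products abar_t = prod_{s<=t} (1 - beta_s)."""
    out, prod = [], 1.0
    for b in betas:
        prod *= 1.0 - b
        out.append(prod)
    return out


def q_sample(x0, t, abar, rng):
    """Draw x_t ~ N(sqrt(abar_t) * x0, (1 - abar_t) * I) and return (x_t, eps)."""
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    a = math.sqrt(abar[t])
    s = math.sqrt(1.0 - abar[t])
    return [a * x + s * e for x, e in zip(x0, eps)], eps


def mse(pred, target):
    """Mean squared error between predicted and actual noise."""
    return sum((p - q) ** 2 for p, q in zip(pred, target)) / len(pred)
```

During training, a timestep t is drawn at random, `q_sample` produces the noisy input, the network predicts `eps` from `(x_t, t)`, and `mse` between the prediction and the true `eps` is the loss.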

Training Details

Parameter            Value
Dataset              CIFAR-10 (50,000 images, 32x32 RGB)
Epochs               200 (checkpoint: ckpt_199.pth)
Optimizer            AdamW, weight_decay=1e-4
LR Schedule          Cosine annealing + linear warmup (GradualWarmupScheduler)
Gradient clipping    Enabled
Data augmentation    RandomHorizontalFlip, Normalize(0.5, 0.5, 0.5)
Workers              4
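
The shape of a cosine-annealing schedule with linear warmup can be sketched as a plain function of the epoch index. The base learning rate, warmup length, and minimum learning rate below are hypothetical placeholders, since the card does not state them, and this function is not the GradualWarmupScheduler from scheduler.py; only the 200-epoch horizon comes from the table.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-4, warmup_epochs=20, total_epochs=200, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr.

    base_lr, warmup_epochs and min_lr are illustrative values only.
    """
    if epoch < warmup_epochs:
        # Ramp linearly so the last warmup epoch reaches base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

Warmup keeps early updates small while the optimizer statistics settle, and the cosine phase then lowers the rate smoothly toward the end of training.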

Repository Contents

File                        Description
model.py                    U-Net architecture (ResBlock, AttnBlock, DownSample, UpSample, TimeEmbedding)
diffusion.py                GaussianDiffusionTrainer + GaussianDiffusionSampler
train.py                    Training loop
scheduler.py                GradualWarmupScheduler
main.py                     Entry point
Checkpoints/ckpt_199.pth    Final model checkpoint
SampledImgs/                Generated sample images

How to Use

import torch
from model import UNet
from diffusion import GaussianDiffusionSampler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the U-Net; the hyperparameters must match those of the checkpoint.
model = UNet(T=1000, ch=128, ch_mult=[1,2,2,2], attn=[1], num_res_blocks=2, dropout=0.1)
model.load_state_dict(torch.load("Checkpoints/ckpt_199.pth", map_location=device))
model.eval().to(device)

# Sampler for the reverse process with the linear beta schedule (1e-4 to 0.02, T=1000).
sampler = GaussianDiffusionSampler(beta_1=1e-4, beta_t=0.02, model=model, T=1000).to(device)

# Start from Gaussian noise and denoise it into 16 images of shape 3x32x32.
with torch.no_grad():
    x_T = torch.randn(16, 3, 32, 32, device=device)
    samples = sampler(x_T)

samples = (samples.clamp(-1, 1) + 1) / 2  # rescale from [-1, 1] to [0, 1]

References

License

MIT

