ddpm-cifar10 / diffusion.py
mlvar's picture
Upload folder using huggingface_hub
065eb11 verified
Raw
History Blame Contribute Delete
7.81 kB
"""Gaussian diffusion process — forward (noising) and reverse (sampling).
Implements the DDPM formulation (Ho et al. 2020):
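- Linear β schedule (Ho et al. 2020) or cosine β schedule (Nichol & Dhariwal 2021); `GaussianDiffusion` defaults to cosine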
- ε-prediction objective
- DDPM ancestral sampling
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Noise schedules
# ---------------------------------------------------------------------------
def linear_beta_schedule(timesteps: int, start: float = 1e-4, end: float = 0.02) -> torch.Tensor:
    """Return `timesteps` β values evenly spaced from `start` to `end` (inclusive).

    Args:
        timesteps: Number of diffusion steps, i.e. the length of the result.
        start: β at the first step.
        end: β at the last step.

    Returns:
        1-D tensor of length `timesteps`.
    """
    betas = torch.linspace(start, end, timesteps)
    return betas
def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Cosine β schedule (Nichol & Dhariwal 2021).

    ᾱ is defined on a grid of ``timesteps + 1`` points as
    ``cos²(((i / timesteps) + s) / (1 + s) · π/2)``, normalised so that
    ᾱ at grid point 0 equals 1.  Each β is then recovered from the ratio of
    consecutive ᾱ values: ``β = 1 - ᾱ[i] / ᾱ[i-1]``.

    Towards the last grid point the cosine argument approaches π/2, so ᾱ
    approaches 0 and the raw β approaches 1.  The result is therefore
    clamped into ``[1e-5, 0.999]`` before being returned.

    Args:
        timesteps: Number of diffusion steps, i.e. the length of the result.
        s: Small offset added to the normalised time before the cosine.

    Returns:
        1-D tensor of `timesteps` β values, each within ``[1e-5, 0.999]``.
    """
    # One more grid point than steps: β needs ᾱ at both ends of every step.
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    abar = torch.cos(phase) ** 2
    abar = abar / abar[0]  # normalise so the curve starts at exactly 1
    step_ratio = abar[1:] / abar[:-1]
    return torch.clamp(1.0 - step_ratio, min=1e-5, max=0.999)
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
"""Gather values from 1-D tensor `a` at indices `t` and broadcast to `x_shape`."""
b = t.shape[0]
out = a[t.cpu()] # a is on CPU, t may be on CUDA — index on CPU
return out.reshape(b, *((1,) * (len(x_shape) - 1))).to(t.device)
# ---------------------------------------------------------------------------
# Gaussian Diffusion
# ---------------------------------------------------------------------------
class GaussianDiffusion:
    """DDPM forward/reverse process with pre-computed coefficients.

    Every coefficient table is a 1-D tensor of length ``timesteps`` that is
    built once in ``__init__`` and looked up per batch element through
    ``_extract``.

    Args:
        timesteps: Number of diffusion steps T; must be at least 1.
        beta_start: First β of the linear schedule (unused for "cosine").
        beta_end: Last β of the linear schedule (unused for "cosine").
        loss_type: "l1" (mean absolute error) or "l2" (mean squared error).
        schedule: "cosine" or "linear".

    Raises:
        ValueError: If `timesteps` is < 1, or `loss_type` / `schedule` is
            not one of the supported names.
    """

    # Supported values for `loss_type`.
    _LOSS_TYPES = ("l1", "l2")

    def __init__(
        self,
        timesteps: int = 1000,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        loss_type: str = "l2",
        schedule: str = "cosine",
    ):
        # Validate cheap arguments up front so a typo fails loudly instead
        # of producing empty tables or silently training with MSE.
        if timesteps < 1:
            raise ValueError(f"timesteps must be >= 1, got {timesteps}")
        if loss_type not in self._LOSS_TYPES:
            raise ValueError(f"Unknown loss_type: {loss_type}")

        self.timesteps = timesteps

        # ── β schedule ──
        if schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start, beta_end)
        else:
            raise ValueError(f"Unknown schedule: {schedule}")

        # ── Precompute coefficients ──
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # ᾱ shifted right by one step, with ᾱ = 1 prepended for the first step.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Forward process: x_t = √ᾱ_t·x_0 + √(1-ᾱ_t)·ε
        self.sqrt_alphas_cumprod = alphas_cumprod.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

        # Reverse process — x₀ recovery: x_0 = √(1/ᾱ_t)·x_t − √(1/ᾱ_t − 1)·ε
        self.sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod).sqrt()
        self.sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1.0).sqrt()

        # Reverse process — posterior q(x_{t-1} | x_t, x_0):
        # mean = coef1·x_0 + coef2·x_t, plus its variance.
        one_minus_acp = 1.0 - alphas_cumprod
        self.posterior_mean_coef1 = (
            betas * alphas_cumprod_prev.sqrt() / one_minus_acp
        )
        self.posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev) * alphas.sqrt() / one_minus_acp
        )
        # Zero at index 0 because alphas_cumprod_prev[0] == 1.
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / one_minus_acp
        )

        # Loss
        self.loss_type = loss_type

    # ------------------------------------------------------------------
    # Forward diffusion
    # ------------------------------------------------------------------
    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Sample x_t given x_0: x_t = √ᾱ_t·x_0 + √(1-ᾱ_t)·ε.

        Args:
            x0: Clean batch.
            t: 1-D tensor of timestep indices, one per batch element.
            noise: ε, broadcastable against `x0`.

        Returns:
            The noised batch x_t.
        """
        s1 = _extract(self.sqrt_alphas_cumprod, t, x0.shape)
        s2 = _extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
        return s1 * x0 + s2 * noise

    # ------------------------------------------------------------------
    # Training loss
    # ------------------------------------------------------------------
    def p_losses(
        self,
        denoise_fn: nn.Module,
        x0: torch.Tensor,
        t: torch.Tensor,
        noise: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """DDPM simple loss between the true noise and the predicted noise.

        Args:
            denoise_fn: Callable invoked as ``denoise_fn(x_t, t)``; its output
                is compared against the noise.
            x0: Clean batch.
            t: 1-D tensor of timestep indices, one per batch element.
            noise: Optional ε; drawn with ``torch.randn_like(x0)`` when omitted.

        Returns:
            L1 loss when ``loss_type == "l1"``, otherwise MSE.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, noise)
        predicted_noise = denoise_fn(xt, t)
        if self.loss_type == "l1":
            return F.l1_loss(predicted_noise, noise)
        return F.mse_loss(predicted_noise, noise)

    # ------------------------------------------------------------------
    # Reverse diffusion (sampling)
    # ------------------------------------------------------------------
    @torch.no_grad()
    def p_sample(
        self, denoise_fn: nn.Module, x: torch.Tensor, t: int, t_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Single DDPM reverse step: x_t → x_{t-1} with x₀ clipping.

        Instead of computing the posterior mean directly from ε, a point
        estimate of x₀ is recovered from the predicted noise, clamped to
        [-1, 1], and the posterior mean is computed from the clamped x₀.
        The clamp bounds the estimate when the noise prediction is imperfect.
        NOTE(review): the clamp presumes data scaled to [-1, 1] — confirm
        against the data pipeline.

        Args:
            denoise_fn: Callable invoked as ``denoise_fn(x, t_tensor)``,
                returning the predicted noise.
            x: Current sample x_t.
            t: Integer timestep; only used to detect the final step
                (``t == 0``), where no noise is added.
            t_tensor: Per-sample timestep indices used for coefficient lookup.

        Returns:
            The sample x_{t-1}.
        """
        eps = denoise_fn(x, t_tensor)

        # Recover predicted x₀ and clip
        sr_ac = _extract(self.sqrt_recip_alphas_cumprod, t_tensor, x.shape)
        sr_m1_ac = _extract(self.sqrt_recipm1_alphas_cumprod, t_tensor, x.shape)
        pred_x0 = sr_ac * x - sr_m1_ac * eps
        pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)

        # Posterior mean from clipped x₀
        coef1 = _extract(self.posterior_mean_coef1, t_tensor, x.shape)
        coef2 = _extract(self.posterior_mean_coef2, t_tensor, x.shape)
        mean = coef1 * pred_x0 + coef2 * x

        if t == 0:
            # Final step: return the mean without adding noise.
            return mean

        var = _extract(self.posterior_variance, t_tensor, x.shape)
        noise = torch.randn_like(x)
        return mean + var.sqrt() * noise

    @torch.no_grad()
    def p_sample_loop(
        self,
        denoise_fn: nn.Module,
        shape: tuple[int, ...],
        device: torch.device,
        progress: bool = False,
        noise: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Full reverse chain: x_T ∼ N(0,I) → … → x_0.

        Args:
            denoise_fn: Noise-prediction callable passed to `p_sample`.
            shape: Shape of the starting noise; used only when `noise` is
                not supplied.
            device: Device on which sampling runs.
            progress: Show a tqdm progress bar (imports tqdm lazily).
            noise: Optional starting tensor x_T. When given it is moved to
                `device`, and its own batch size is used for the timestep
                tensors.

        Returns:
            The final sample after `timesteps` reverse steps.
        """
        if noise is not None:
            img = noise.to(device)
        else:
            img = torch.randn(shape, device=device)

        # Take the batch size from the tensor actually being denoised, so a
        # caller-supplied `noise` can never disagree with the timestep tensor.
        b = img.shape[0]

        timestep_range = reversed(range(self.timesteps))
        if progress:
            from tqdm import tqdm

            # A reversed range has no len(), so give tqdm the total
            # explicitly; without it the bar shows no percentage or ETA.
            timestep_range = tqdm(
                timestep_range, desc="Sampling", leave=False, total=self.timesteps
            )

        for t in timestep_range:
            t_tensor = torch.full((b,), t, device=device, dtype=torch.long)
            img = self.p_sample(denoise_fn, img, t, t_tensor)
        return img