File size: 7,805 Bytes
"""Gaussian diffusion process — forward (noising) and reverse (sampling).

Implements the DDPM formulation (Ho et al. 2020), with the cosine schedule
of Nichol & Dhariwal (2021) as the default:
- Linear or cosine β schedule
- ε-prediction objective
- DDPM ancestral sampling with x₀ clipping
"""
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Noise schedules
# ---------------------------------------------------------------------------
def linear_beta_schedule(timesteps: int, start: float = 1e-4, end: float = 0.02) -> torch.Tensor:
    """Build a β schedule whose values are evenly spaced between two endpoints.

    Args:
        timesteps: Number of β values to generate.
        start: Value of the first β.
        end: Value of the last β.

    Returns:
        A 1-D tensor with `timesteps` entries running linearly from `start`
        to `end` (both endpoints included).
    """
    betas = torch.linspace(start=start, end=end, steps=timesteps)
    return betas
def cosine_beta_schedule(timesteps: int, s: float = 0.008) -> torch.Tensor:
    """Cosine β schedule (Nichol & Dhariwal 2021).

    The cumulative signal level ᾱ_t is defined directly as a squared cosine
    of the (offset) normalised timestep, normalised so that ᾱ_0 = 1. Each
    β_t is then recovered from consecutive ᾱ values as
    β_t = 1 - ᾱ_t / ᾱ_{t-1}.

    At the last step the cosine argument reaches π/2, so ᾱ_T is close to 0
    and the raw β approaches 1. The returned values are therefore clamped
    to the range [1e-5, 0.999].

    Args:
        timesteps: Number of β values to generate.
        s: Small offset added to the normalised timestep before the cosine.

    Returns:
        A 1-D tensor of `timesteps` β values, each within [1e-5, 0.999].
    """
    # timesteps + 1 grid points are needed to form `timesteps` consecutive ratios.
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    alpha_bar = torch.cos(phase) ** 2
    # Normalise so the curve starts at exactly 1.
    alpha_bar = alpha_bar / alpha_bar[0]
    step_ratio = alpha_bar[1:] / alpha_bar[:-1]
    return (1.0 - step_ratio).clamp(min=1e-5, max=0.999)
# ---------------------------------------------------------------------------
# Helper
# ---------------------------------------------------------------------------
def _extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
"""Gather values from 1-D tensor `a` at indices `t` and broadcast to `x_shape`."""
b = t.shape[0]
out = a[t.cpu()] # a is on CPU, t may be on CUDA — index on CPU
return out.reshape(b, *((1,) * (len(x_shape) - 1))).to(t.device)
# ---------------------------------------------------------------------------
# Gaussian Diffusion
# ---------------------------------------------------------------------------
class GaussianDiffusion:
    """DDPM forward/reverse process with pre-computed coefficients.

    Every per-timestep coefficient is computed once in ``__init__`` and kept
    as a plain 1-D tensor attribute with one entry per timestep. The methods
    look the values up for a batch of timesteps through ``_extract``.
    """

    def __init__(
        self,
        timesteps: int = 1000,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        loss_type: str = "l2",
        schedule: str = "cosine",
    ):
        """Build the β schedule and derive all forward/reverse coefficients.

        Args:
            timesteps: Number of diffusion steps T.
            beta_start: First β of the linear schedule (ignored for "cosine").
            beta_end: Last β of the linear schedule (ignored for "cosine").
            loss_type: ``"l1"`` selects an L1 loss in ``p_losses``; any other
                value (including the default ``"l2"``) selects MSE.
            schedule: ``"cosine"`` or ``"linear"``.

        Raises:
            ValueError: If `schedule` is neither ``"cosine"`` nor ``"linear"``.
        """
        self.timesteps = timesteps

        # ── β schedule ──
        if schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start, beta_end)
        else:
            raise ValueError(f"Unknown schedule: {schedule}")

        # ── Precompute coefficients ──
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # ᾱ_t
        # ᾱ_{t-1}, with ᾱ_{-1} defined as 1 (no noise before the first step).
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Forward process: x_t = √ᾱ_t·x_0 + √(1-ᾱ_t)·ε
        self.sqrt_alphas_cumprod = alphas_cumprod.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

        # Reverse process — x₀ recovery (for clipping):
        # x_0 = √(1/ᾱ_t)·x_t − √(1/ᾱ_t − 1)·ε
        self.sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod).sqrt()
        self.sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1.0).sqrt()

        # Reverse process — posterior q(x_{t-1} | x_t, x_0):
        # mean = coef1·x_0 + coef2·x_t
        self.posterior_mean_coef1 = (
            betas * alphas_cumprod_prev.sqrt() / (1.0 - alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev) * alphas.sqrt() / (1.0 - alphas_cumprod)
        )
        # Evaluates to 0 at t = 0 because ᾱ_{-1} = 1.
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

        # Loss
        self.loss_type = loss_type

    # ------------------------------------------------------------------
    # Forward diffusion
    # ------------------------------------------------------------------
    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Sample x_t given x_0: x_t = √ᾱ_t·x_0 + √(1-ᾱ_t)·ε.

        Args:
            x0: Clean input batch.
            t: Timestep indices, one per batch element.
            noise: The ε to mix in; combined element-wise with `x0`.

        Returns:
            The noised batch x_t.
        """
        s1 = _extract(self.sqrt_alphas_cumprod, t, x0.shape)
        s2 = _extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
        return s1 * x0 + s2 * noise

    # ------------------------------------------------------------------
    # Training loss
    # ------------------------------------------------------------------
    def p_losses(
        self,
        denoise_fn: nn.Module,
        x0: torch.Tensor,
        t: torch.Tensor,
        noise: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """DDPM simple loss between the true noise and the predicted noise.

        Args:
            denoise_fn: Model called as ``denoise_fn(x_t, t)`` that returns
                the predicted noise.
            x0: Clean input batch.
            t: Timestep indices, one per batch element.
            noise: Optional ε to use; drawn from N(0, I) when omitted.

        Returns:
            The L1 loss when ``loss_type == "l1"``; MSE for any other
            `loss_type` value.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        xt = self.q_sample(x0, t, noise)
        predicted_noise = denoise_fn(xt, t)
        if self.loss_type == "l1":
            return F.l1_loss(predicted_noise, noise)
        # NOTE: any loss_type other than "l1" silently falls through to MSE.
        return F.mse_loss(predicted_noise, noise)

    # ------------------------------------------------------------------
    # Reverse diffusion (sampling)
    # ------------------------------------------------------------------
    @torch.no_grad()
    def p_sample(
        self, denoise_fn: nn.Module, x: torch.Tensor, t: int, t_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Single DDPM reverse step: x_t → x_{t-1} with x₀ clipping.

        Instead of computing the posterior mean directly from ε, we first
        recover a point estimate of x₀ from the predicted noise, CLIP it to
        [-1, 1], then compute the posterior mean from the clipped x₀.
        The clipping keeps the estimate bounded when the model's noise
        prediction is imperfect, which matters most for the cosine schedule,
        whose late-step β values are far larger than the linear schedule's
        default maximum of 0.02.

        Args:
            denoise_fn: Model called as ``denoise_fn(x, t_tensor)`` that
                returns the predicted noise.
            x: Current sample x_t.
            t: Current timestep as a Python int; only used to detect the
                final step (``t == 0``).
            t_tensor: The same timestep as a per-sample index tensor, passed
                to the model and used for coefficient lookup.

        Returns:
            The posterior mean when ``t == 0``; otherwise the mean plus
            Gaussian noise scaled by the posterior standard deviation.
        """
        eps = denoise_fn(x, t_tensor)

        # Recover predicted x₀ and clip
        # (assumes data scaled to [-1, 1] — confirm against the data pipeline)
        sr_ac = _extract(self.sqrt_recip_alphas_cumprod, t_tensor, x.shape)
        sr_m1_ac = _extract(self.sqrt_recipm1_alphas_cumprod, t_tensor, x.shape)
        pred_x0 = sr_ac * x - sr_m1_ac * eps
        pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)

        # Posterior mean from clipped x₀
        coef1 = _extract(self.posterior_mean_coef1, t_tensor, x.shape)
        coef2 = _extract(self.posterior_mean_coef2, t_tensor, x.shape)
        mean = coef1 * pred_x0 + coef2 * x

        # No noise is added on the final step.
        if t == 0:
            return mean
        var = _extract(self.posterior_variance, t_tensor, x.shape)
        noise = torch.randn_like(x)
        return mean + var.sqrt() * noise

    @torch.no_grad()
    def p_sample_loop(
        self,
        denoise_fn: nn.Module,
        shape: tuple[int, ...],
        device: torch.device,
        progress: bool = False,
        noise: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Full reverse chain: x_T ∼ N(0,I) → … → x_0.

        Args:
            denoise_fn: Noise-prediction model, forwarded to ``p_sample``.
            shape: Shape of the batch to generate; ``shape[0]`` is the batch
                size used for the per-step timestep tensor.
            device: Device on which sampling runs.
            progress: When true, wrap the loop in a ``tqdm`` progress bar
                (``tqdm`` is imported lazily, only in that case).
            noise: Optional starting x_T. It is moved to `device` and used in
                place of freshly drawn noise. It is expected to have shape
                `shape`, since the batch size is still read from `shape`.

        Returns:
            The final sample after stepping through every timestep from
            ``timesteps - 1`` down to 0.
        """
        b = shape[0]
        if noise is not None:
            img = noise.to(device)
        else:
            img = torch.randn(shape, device=device)

        timestep_range = reversed(range(self.timesteps))
        if progress:
            from tqdm import tqdm

            # A reversed range has no len(), so tqdm cannot infer the number
            # of steps on its own; pass it explicitly to get percentage/ETA.
            timestep_range = tqdm(
                timestep_range, total=self.timesteps, desc="Sampling", leave=False
            )

        for t in timestep_range:
            t_tensor = torch.full((b,), t, device=device, dtype=torch.long)
            img = self.p_sample(denoise_fn, img, t, t_tensor)
        return img
|