| """Gaussian diffusion process — forward (noising) and reverse (sampling). |
| |
| Implements the DDPM formulation (Ho et al. 2020): |
| - Linear or cosine β schedule (cosine is the default) |
| - ε-prediction objective |
| - DDPM ancestral sampling |
| """ |
|
|
| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
def linear_beta_schedule(timesteps: int, start: float = 1e-4, end: float = 0.02) -> torch.Tensor:
    """Return `timesteps` β values evenly spaced from `start` to `end`.

    Args:
        timesteps: Number of diffusion steps, i.e. the length of the result.
        start: β assigned to the first step.
        end: β assigned to the last step.

    Returns:
        1-D tensor of length `timesteps`.
    """
    betas = torch.linspace(start=start, end=end, steps=timesteps)
    return betas
|
|
|
|
def cosine_beta_schedule(
    timesteps: int,
    s: float = 0.008,
    *,
    min_beta: float = 1e-5,
    max_beta: float = 0.999,
) -> torch.Tensor:
    """Cosine β schedule (Nichol & Dhariwal 2021).

    The cumulative signal level follows a squared-cosine curve,

        ᾱ_t = f(t) / f(0),   f(t) = cos²(((t / T) + s) / (1 + s) · π / 2),

    and the per-step noise level is recovered as β_t = 1 − ᾱ_t / ᾱ_{t−1}.

    The raw β values grow towards 1 as t → T because ᾱ_T is (numerically)
    zero, so the result is clamped to [`min_beta`, `max_beta`]. With the
    defaults the last β sits exactly on the upper clamp (0.999); keeping
    `max_beta` below 1 keeps every α_t = 1 − β_t strictly positive.

    Args:
        timesteps: Number of diffusion steps T; the result has this length.
        s: Small offset applied inside the cosine.
        min_beta: Lower clamp applied to every β.
        max_beta: Upper clamp applied to every β.

    Returns:
        1-D tensor of `timesteps` β values.

    Raises:
        ValueError: If `min_beta` is greater than `max_beta`.
    """
    if min_beta > max_beta:
        raise ValueError(
            f"min_beta ({min_beta}) must not exceed max_beta ({max_beta})"
        )

    # T + 1 grid points are needed so that T consecutive ratios can be formed.
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
    # Normalise so that ᾱ at grid point 0 is exactly 1.
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return torch.clamp(betas, min=min_beta, max=max_beta)
|
|
|
|
| |
| |
| |
|
|
def _extract(a: torch.Tensor, t: torch.Tensor, x_shape: torch.Size) -> torch.Tensor:
    """Look up `a[t]` per batch element and shape it to broadcast against `x_shape`.

    Args:
        a: Table of per-timestep values (documented by callers as 1-D).
        t: Index tensor whose first dimension is the batch size.
        x_shape: Shape of the tensor the result will be combined with; only
            its number of dimensions is used.

    Returns:
        Tensor of shape (batch, 1, ..., 1) with `len(x_shape)` dimensions,
        placed on the same device as `t`.
    """
    batch_size = t.shape[0]
    trailing_ones = (1,) * (len(x_shape) - 1)
    # Index with a host-side copy of `t`, then move the picked values to t's device.
    picked = a[t.cpu()]
    return picked.reshape(batch_size, *trailing_ones).to(t.device)
|
|
|
|
| |
| |
| |
|
|
class GaussianDiffusion:
    """DDPM forward/reverse process with pre-computed coefficients.

    Every coefficient table is a 1-D tensor with one entry per timestep,
    built once in `__init__` and looked up per batch element via `_extract`.
    """

    def __init__(
        self,
        timesteps: int = 1000,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        loss_type: str = "l2",
        schedule: str = "cosine",
    ):
        """Build the β schedule and all coefficients derived from it.

        Args:
            timesteps: Number of diffusion steps T.
            beta_start: First β of the linear schedule; ignored when
                `schedule == "cosine"`.
            beta_end: Last β of the linear schedule; ignored when
                `schedule == "cosine"`.
            loss_type: "l1" selects the L1 loss in `p_losses`; any other
                value selects MSE.
            schedule: Either "cosine" or "linear".

        Raises:
            ValueError: If `schedule` is neither "cosine" nor "linear".
        """
        self.timesteps = timesteps

        # --- noise schedule ---------------------------------------------
        if schedule == "cosine":
            betas = cosine_beta_schedule(timesteps)
        elif schedule == "linear":
            betas = linear_beta_schedule(timesteps, beta_start, beta_end)
        else:
            raise ValueError(f"Unknown schedule: {schedule}")

        # --- α_t, ᾱ_t and ᾱ_{t-1} ---------------------------------------
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # Shift right by one and prepend 1.0, so index t holds ᾱ_{t-1}.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # --- forward process q(x_t | x_0) -------------------------------
        self.sqrt_alphas_cumprod = alphas_cumprod.sqrt()
        self.sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

        # --- x_0 estimate from (x_t, ε): x_0 = √(1/ᾱ_t)·x_t − √(1/ᾱ_t − 1)·ε
        self.sqrt_recip_alphas_cumprod = (1.0 / alphas_cumprod).sqrt()
        self.sqrt_recipm1_alphas_cumprod = (1.0 / alphas_cumprod - 1.0).sqrt()

        # --- posterior q(x_{t-1} | x_t, x_0) ----------------------------
        # mean = coef1 · x_0 + coef2 · x_t
        self.posterior_mean_coef1 = (
            betas * alphas_cumprod_prev.sqrt() / (1.0 - alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - alphas_cumprod_prev) * alphas.sqrt() / (1.0 - alphas_cumprod)
        )
        # Entry 0 is exactly 0 because ᾱ_{t-1} is padded with 1.0 there.
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )

        self.loss_type = loss_type

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Sample x_t given x_0: x_t = √ᾱ_t·x_0 + √(1-ᾱ_t)·ε.

        Args:
            x0: Clean input batch.
            t: Per-sample timestep indices.
            noise: ε, same shape as `x0`.

        Returns:
            The noised batch x_t.
        """
        s1 = _extract(self.sqrt_alphas_cumprod, t, x0.shape)
        s2 = _extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
        return s1 * x0 + s2 * noise

    def p_losses(
        self,
        denoise_fn: nn.Module,
        x0: torch.Tensor,
        t: torch.Tensor,
        noise: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Loss between the true noise and the noise predicted from x_t.

        Args:
            denoise_fn: Called as `denoise_fn(x_t, t)`; must return the
                predicted noise.
            x0: Clean input batch.
            t: Per-sample timestep indices.
            noise: ε to use for the forward process; drawn from N(0, I) with
                the shape of `x0` when omitted.

        Returns:
            L1 loss when `loss_type == "l1"`, MSE loss otherwise.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        xt = self.q_sample(x0, t, noise)
        predicted_noise = denoise_fn(xt, t)

        if self.loss_type == "l1":
            return F.l1_loss(predicted_noise, noise)
        return F.mse_loss(predicted_noise, noise)

    @torch.no_grad()
    def p_sample(
        self, denoise_fn: nn.Module, x: torch.Tensor, t: int, t_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Single DDPM reverse step: x_t → x_{t-1} with x₀ clipping.

        Instead of computing the posterior mean directly from ε, we first
        recover a point estimate of x₀ from the predicted noise, CLIP it to
        [-1, 1], then compute the posterior mean from the clipped x₀.

        The clipping bounds the x₀ estimate when the model's noise
        prediction is imperfect. It also means samples are produced on a
        [-1, 1] scale.

        Args:
            denoise_fn: Called as `denoise_fn(x, t_tensor)`; must return the
                predicted noise.
            x: Current sample x_t.
            t: Timestep as a Python int; only used to detect the last step.
            t_tensor: The same timestep as a per-sample index tensor.

        Returns:
            x_{t-1}. At `t == 0` this is the posterior mean with no noise
            added.
        """
        eps = denoise_fn(x, t_tensor)

        # Point estimate of x_0 from (x_t, ε), clipped to the data range.
        sr_ac = _extract(self.sqrt_recip_alphas_cumprod, t_tensor, x.shape)
        sr_m1_ac = _extract(self.sqrt_recipm1_alphas_cumprod, t_tensor, x.shape)
        pred_x0 = sr_ac * x - sr_m1_ac * eps
        pred_x0 = torch.clamp(pred_x0, -1.0, 1.0)

        # Posterior mean of q(x_{t-1} | x_t, x_0) using the clipped x_0.
        coef1 = _extract(self.posterior_mean_coef1, t_tensor, x.shape)
        coef2 = _extract(self.posterior_mean_coef2, t_tensor, x.shape)
        mean = coef1 * pred_x0 + coef2 * x

        if t == 0:
            return mean

        var = _extract(self.posterior_variance, t_tensor, x.shape)
        noise = torch.randn_like(x)
        return mean + var.sqrt() * noise

    @torch.no_grad()
    def p_sample_loop(
        self,
        denoise_fn: nn.Module,
        shape: tuple[int, ...],
        device: torch.device,
        progress: bool = False,
        noise: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Full reverse chain: x_T ∼ N(0,I) → … → x_0.

        Args:
            denoise_fn: Noise-prediction model passed to `p_sample`.
            shape: Shape of the batch to draw when `noise` is None.
            device: Device the chain runs on.
            progress: Show a tqdm progress bar. tqdm is imported lazily, so
                it is only required when this is True.
            noise: Optional starting x_T. When given it is moved to `device`
                and used as-is; `shape` is not consulted.

        Returns:
            The final sample x_0.
        """
        if noise is not None:
            img = noise.to(device)
        else:
            img = torch.randn(shape, device=device)

        # Take the batch size from the tensor actually being denoised rather
        # than from `shape`: a caller-supplied `noise` may have a different
        # leading dimension, and the per-step coefficients must match it.
        batch_size = img.shape[0]

        timestep_range = reversed(range(self.timesteps))
        if progress:
            from tqdm import tqdm
            timestep_range = tqdm(timestep_range, desc="Sampling", leave=False)

        for t in timestep_range:
            t_tensor = torch.full((batch_size,), t, device=device, dtype=torch.long)
            img = self.p_sample(denoise_fn, img, t, t_tensor)

        return img
|
|