""" Gaussian Diffusion (DDPM) framework for PDE next-frame prediction. """ import torch import torch.nn as nn import torch.nn.functional as F import math class GaussianDiffusion(nn.Module): """DDPM with linear beta schedule. Training: given (condition, target), add noise to target, predict noise. Sampling: iteratively denoise starting from Gaussian noise. Args: model: U-Net (or any eps-predicting network). timesteps: number of diffusion steps. beta_start: starting noise level. beta_end: ending noise level. """ def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02): super().__init__() self.model = model self.T = timesteps # --- precompute schedule --- betas = torch.linspace(beta_start, beta_end, timesteps) alphas = 1.0 - betas alpha_bar = torch.cumprod(alphas, dim=0) self.register_buffer("betas", betas) self.register_buffer("alphas", alphas) self.register_buffer("alpha_bar", alpha_bar) self.register_buffer("sqrt_alpha_bar", torch.sqrt(alpha_bar)) self.register_buffer("sqrt_one_minus_alpha_bar", torch.sqrt(1 - alpha_bar)) self.register_buffer("sqrt_recip_alpha", torch.sqrt(1.0 / alphas)) self.register_buffer( "posterior_variance", betas * (1.0 - F.pad(alpha_bar[:-1], (1, 0), value=1.0)) / (1.0 - alpha_bar), ) def q_sample(self, x0, t, noise=None): """Forward process: add noise to x0 at timestep t.""" if noise is None: noise = torch.randn_like(x0) a = self.sqrt_alpha_bar[t][:, None, None, None] b = self.sqrt_one_minus_alpha_bar[t][:, None, None, None] return a * x0 + b * noise, noise def training_loss(self, x_target, x_cond): """Compute training loss (predict noise). Args: x_target: clean target frames [B, C, H, W]. x_cond: condition frames [B, C, H, W]. Returns: scalar MSE loss. """ B = x_target.shape[0] t = torch.randint(0, self.T, (B,), device=x_target.device) noise = torch.randn_like(x_target) x_noisy, _ = self.q_sample(x_target, t, noise) eps_pred = self.model(x_noisy, t, cond=x_cond) return F.mse_loss(eps_pred, noise) @torch.no_grad() def sample(self, x_cond, shape=None): """Generate target frames by iterative denoising (DDPM). Args: x_cond: condition frames [B, C_cond, H, W]. shape: (B, C_out, H, W) of the target. Inferred if None. Returns: denoised sample [B, C_out, H, W]. """ device = x_cond.device if shape is None: shape = x_cond.shape # assume same channels x = torch.randn(shape, device=device) for i in reversed(range(self.T)): t = torch.full((shape[0],), i, device=device, dtype=torch.long) eps = self.model(x, t, cond=x_cond) alpha = self.alphas[i] alpha_bar = self.alpha_bar[i] beta = self.betas[i] mean = (1.0 / alpha.sqrt()) * (x - beta / (1 - alpha_bar).sqrt() * eps) if i > 0: sigma = self.posterior_variance[i].sqrt() x = mean + sigma * torch.randn_like(x) else: x = mean return x @torch.no_grad() def sample_ddim(self, x_cond, shape=None, steps=50, eta=0.0): """DDIM accelerated sampling. Args: x_cond: condition [B, C_cond, H, W]. shape: target shape. steps: number of DDIM steps (< 0: x = x + sigma * torch.randn_like(x) return x