|
|
""" |
|
|
Gaussian Diffusion (DDPM) framework for PDE next-frame prediction. |
|
|
""" |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import math |
|
|
|
|
|
|
|
|
class GaussianDiffusion(nn.Module):
    """DDPM with a linear beta schedule for conditional next-frame prediction.

    Training: given (condition, target), diffuse the target to a random
    timestep and train the model to predict the injected noise
    (epsilon-prediction parameterization).
    Sampling: iteratively denoise starting from pure Gaussian noise, either
    with the full DDPM ancestral chain (``sample``) or with accelerated
    DDIM (``sample_ddim``).

    Args:
        model: U-Net (or any eps-predicting network) with signature
            ``model(x_t, t, cond=...) -> eps``, where ``eps`` has the same
            shape as ``x_t``.
        timesteps: number of diffusion steps T.
        beta_start: starting noise level of the linear schedule.
        beta_end: ending noise level of the linear schedule.
    """

    def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.model = model
        self.T = timesteps

        # Linear beta schedule and its derived per-step quantities.
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)

        # Buffers (not parameters): they follow .to(device) and are stored in
        # the state_dict, but receive no gradients.
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bar", alpha_bar)
        self.register_buffer("sqrt_alpha_bar", torch.sqrt(alpha_bar))
        self.register_buffer("sqrt_one_minus_alpha_bar", torch.sqrt(1 - alpha_bar))
        # 1/sqrt(alpha_t), used by the DDPM posterior-mean step in sample().
        self.register_buffer("sqrt_recip_alpha", torch.sqrt(1.0 / alphas))
        # Variance of q(x_{t-1} | x_t, x_0). The pad supplies alpha_bar_{-1}=1,
        # which makes the t=0 entry exactly 0 (never used: sampling adds no
        # noise on the final step).
        self.register_buffer(
            "posterior_variance",
            betas * (1.0 - F.pad(alpha_bar[:-1], (1, 0), value=1.0)) / (1.0 - alpha_bar),
        )

    def q_sample(self, x0, t, noise=None):
        """Forward process: diffuse x0 to timestep t in closed form.

        Args:
            x0: clean frames [B, C, H, W].
            t: integer timesteps [B].
            noise: optional Gaussian noise shaped like ``x0``; drawn fresh
                when None.

        Returns:
            Tuple ``(x_t, noise)`` with
            ``x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise``.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Broadcast the per-sample scalar coefficients over [C, H, W].
        a = self.sqrt_alpha_bar[t][:, None, None, None]
        b = self.sqrt_one_minus_alpha_bar[t][:, None, None, None]
        return a * x0 + b * noise, noise

    def training_loss(self, x_target, x_cond):
        """Compute the epsilon-prediction training loss.

        Args:
            x_target: clean target frames [B, C, H, W].
            x_cond: condition frames [B, C, H, W].

        Returns:
            scalar MSE between predicted and true noise.
        """
        B = x_target.shape[0]
        # One uniformly random timestep per sample in the batch.
        t = torch.randint(0, self.T, (B,), device=x_target.device)
        noise = torch.randn_like(x_target)
        x_noisy, _ = self.q_sample(x_target, t, noise)

        eps_pred = self.model(x_noisy, t, cond=x_cond)
        return F.mse_loss(eps_pred, noise)

    @torch.no_grad()
    def sample(self, x_cond, shape=None):
        """Generate target frames by full DDPM ancestral sampling.

        Args:
            x_cond: condition frames [B, C_cond, H, W].
            shape: (B, C_out, H, W) of the target. Defaults to
                ``x_cond.shape`` (i.e. assumes C_out == C_cond).

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape

        x = torch.randn(shape, device=device)

        for i in reversed(range(self.T)):
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            eps = self.model(x, t, cond=x_cond)

            # Posterior mean mu_theta(x_t, t). Uses the precomputed
            # 1/sqrt(alpha_t) buffer instead of recomputing it each step.
            mean = self.sqrt_recip_alpha[i] * (
                x - self.betas[i] / (1 - self.alpha_bar[i]).sqrt() * eps
            )

            if i > 0:
                sigma = self.posterior_variance[i].sqrt()
                x = mean + sigma * torch.randn_like(x)
            else:
                # Final step is deterministic: no noise injected at t=0.
                x = mean

        return x

    @torch.no_grad()
    def sample_ddim(self, x_cond, shape=None, steps=50, eta=0.0):
        """DDIM accelerated sampling on a subsequence of timesteps.

        Args:
            x_cond: condition [B, C_cond, H, W].
            shape: target shape; defaults to ``x_cond.shape``.
            steps: number of DDIM steps (<< T for speed).
            eta: stochasticity (0 = deterministic DDIM, 1 = DDPM-like).

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape

        # steps+1 timesteps spanning [0, T-1]; flipped so consecutive pairs
        # form the (t_cur -> t_next) descending jumps.
        step_indices = torch.linspace(0, self.T - 1, steps + 1, dtype=torch.long, device=device)
        step_indices = step_indices.flip(0)

        x = torch.randn(shape, device=device)

        for idx in range(len(step_indices) - 1):
            t_cur = step_indices[idx]
            t_next = step_indices[idx + 1]

            t_batch = t_cur.expand(shape[0])
            eps = self.model(x, t_batch, cond=x_cond)

            ab_cur = self.alpha_bar[t_cur]
            ab_next = self.alpha_bar[t_next]

            # Predict x0 from the current noisy sample; clamp keeps the
            # iteration numerically stable against out-of-range predictions.
            x0_pred = (x - (1 - ab_cur).sqrt() * eps) / ab_cur.sqrt()
            x0_pred = x0_pred.clamp(-5, 5)

            # DDIM update (Song et al. 2021, Eq. 12): deterministic direction
            # toward x_{t_next} plus optional eta-scaled noise.
            sigma = eta * ((1 - ab_next) / (1 - ab_cur) * (1 - ab_cur / ab_next)).sqrt()
            dir_xt = (1 - ab_next - sigma**2).sqrt() * eps

            x = ab_next.sqrt() * x0_pred + dir_xt
            if sigma > 0:
                x = x + sigma * torch.randn_like(x)

        return x
|
|
|