File size: 4,987 Bytes
70ced22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
"""
Gaussian Diffusion (DDPM) framework for PDE next-frame prediction.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class GaussianDiffusion(nn.Module):
    """DDPM with a linear beta schedule for conditional frame prediction.

    Training: given (condition, target), add noise to the target at a random
    timestep and train the model to predict that noise (epsilon-prediction).
    Sampling: start from Gaussian noise and iteratively denoise, conditioned
    on the given frames — either ancestral DDPM or accelerated DDIM.

    Args:
        model: eps-predicting network, called as ``model(x_t, t, cond=...)``.
        timesteps: number of diffusion steps T.
        beta_start: first beta of the linear schedule.
        beta_end: last beta of the linear schedule.
    """

    def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.model = model
        self.T = timesteps
        # --- precompute the schedule once; buffers follow .to(device) ---
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bar", alpha_bar)
        self.register_buffer("sqrt_alpha_bar", torch.sqrt(alpha_bar))
        self.register_buffer("sqrt_one_minus_alpha_bar", torch.sqrt(1 - alpha_bar))
        self.register_buffer("sqrt_recip_alpha", torch.sqrt(1.0 / alphas))
        # Variance of q(x_{t-1} | x_t, x_0); alpha_bar_{t-1} is padded with 1
        # at t=0 (posterior_variance[0] == 0 and is never used in sampling).
        self.register_buffer(
            "posterior_variance",
            betas * (1.0 - F.pad(alpha_bar[:-1], (1, 0), value=1.0)) / (1.0 - alpha_bar),
        )

    def q_sample(self, x0, t, noise=None):
        """Forward process: sample x_t ~ q(x_t | x_0) at timestep t.

        Args:
            x0: clean input [B, C, H, W].
            t: per-sample timesteps [B], long.
            noise: optional pre-drawn standard normal noise; drawn if None.

        Returns:
            Tuple ``(x_t, noise)`` — the noised input and the noise used.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        a = self.sqrt_alpha_bar[t][:, None, None, None]
        b = self.sqrt_one_minus_alpha_bar[t][:, None, None, None]
        return a * x0 + b * noise, noise

    def training_loss(self, x_target, x_cond):
        """Compute the simple DDPM training loss (epsilon-prediction MSE).

        Args:
            x_target: clean target frames [B, C, H, W].
            x_cond: condition frames [B, C, H, W].

        Returns:
            scalar MSE loss between predicted and true noise.
        """
        B = x_target.shape[0]
        t = torch.randint(0, self.T, (B,), device=x_target.device)
        # q_sample draws the noise and returns it; no need to pre-draw here.
        x_noisy, noise = self.q_sample(x_target, t)
        eps_pred = self.model(x_noisy, t, cond=x_cond)
        return F.mse_loss(eps_pred, noise)

    @torch.no_grad()
    def sample(self, x_cond, shape=None):
        """Generate target frames by ancestral DDPM sampling (T model calls).

        Args:
            x_cond: condition frames [B, C_cond, H, W].
            shape: (B, C_out, H, W) of the target; defaults to x_cond.shape
                (i.e. assumes the target has the same channel count).

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape  # assume target matches condition layout
        x = torch.randn(shape, device=device)
        for i in reversed(range(self.T)):
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            eps = self.model(x, t, cond=x_cond)
            # Posterior mean: (1/sqrt(alpha_t)) * (x_t - beta_t/sqrt(1-ab_t) * eps).
            # Uses the precomputed buffers (previously recomputed inline each
            # step while sqrt_recip_alpha sat unused).
            mean = self.sqrt_recip_alpha[i] * (
                x - self.betas[i] / self.sqrt_one_minus_alpha_bar[i] * eps
            )
            if i > 0:
                sigma = self.posterior_variance[i].sqrt()
                x = mean + sigma * torch.randn_like(x)
            else:
                x = mean  # final step is deterministic (no noise at t=0)
        return x

    @torch.no_grad()
    def sample_ddim(self, x_cond, shape=None, steps=50, eta=0.0):
        """DDIM accelerated sampling (Song et al., arXiv:2010.02502).

        Args:
            x_cond: condition frames [B, C_cond, H, W].
            shape: (B, C_out, H, W) of the target; defaults to x_cond.shape.
            steps: number of DDIM steps (<< T for speed).
            eta: stochasticity in [0, 1] (0 = deterministic DDIM, 1 = DDPM-like).

        Returns:
            denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape
        # Sub-sample steps+1 timesteps uniformly over [0, T-1]; reversed so
        # consecutive pairs give the (t_cur -> t_next) denoising transitions.
        step_indices = torch.linspace(0, self.T - 1, steps + 1, dtype=torch.long, device=device)
        step_indices = step_indices.flip(0)  # T-1 ... 0
        x = torch.randn(shape, device=device)
        for idx in range(len(step_indices) - 1):
            t_cur = step_indices[idx]
            t_next = step_indices[idx + 1]
            t_batch = t_cur.expand(shape[0])
            eps = self.model(x, t_batch, cond=x_cond)
            ab_cur = self.alpha_bar[t_cur]
            ab_next = self.alpha_bar[t_next]
            # Predict x0 from the current sample and the predicted noise.
            x0_pred = (x - (1 - ab_cur).sqrt() * eps) / ab_cur.sqrt()
            x0_pred = x0_pred.clamp(-5, 5)  # stability clamp
            # DDIM variance; zero when eta == 0 (fully deterministic).
            sigma = eta * ((1 - ab_next) / (1 - ab_cur) * (1 - ab_cur / ab_next)).sqrt()
            # Direction pointing towards x_t. Clamp the radicand at 0: it is
            # non-negative analytically, but float rounding can push it
            # slightly negative when eta is near 1, producing NaNs via sqrt.
            dir_xt = (1 - ab_next - sigma**2).clamp(min=0.0).sqrt() * eps
            x = ab_next.sqrt() * x0_pred + dir_xt
            if sigma > 0:
                x = x + sigma * torch.randn_like(x)
        return x
|