# the-well-diffusion / diffusion.py
# Uploaded by AlexWortega via huggingface_hub (commit 70ced22, verified)
"""
Gaussian Diffusion (DDPM) framework for PDE next-frame prediction.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class GaussianDiffusion(nn.Module):
    """DDPM with a linear beta schedule, for conditional next-frame prediction.

    Training: given (condition, target), add noise to the target at a random
    timestep and train the model to predict that noise (eps-parameterization).
    Sampling: iteratively denoise starting from Gaussian noise, conditioned
    on the given frames.

    Args:
        model: eps-predicting network (e.g. a U-Net), called as
            ``model(x_t, t, cond=x_cond)``.
        timesteps: number of diffusion steps T.
        beta_start: starting noise level of the linear schedule.
        beta_end: ending noise level of the linear schedule.
    """

    def __init__(self, model, timesteps=1000, beta_start=1e-4, beta_end=0.02):
        super().__init__()
        self.model = model
        self.T = timesteps
        # --- precompute schedule; registered as buffers so they follow the
        # module across devices and are saved in the state_dict ---
        betas = torch.linspace(beta_start, beta_end, timesteps)
        alphas = 1.0 - betas
        alpha_bar = torch.cumprod(alphas, dim=0)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bar", alpha_bar)
        self.register_buffer("sqrt_alpha_bar", torch.sqrt(alpha_bar))
        self.register_buffer("sqrt_one_minus_alpha_bar", torch.sqrt(1 - alpha_bar))
        self.register_buffer("sqrt_recip_alpha", torch.sqrt(1.0 / alphas))
        # Variance of q(x_{t-1} | x_t, x_0). The left-pad shifts alpha_bar
        # right so element 0 uses alpha_bar_{-1} := 1; hence
        # posterior_variance[0] == 0 and the final DDPM step is deterministic.
        self.register_buffer(
            "posterior_variance",
            betas * (1.0 - F.pad(alpha_bar[:-1], (1, 0), value=1.0)) / (1.0 - alpha_bar),
        )

    def q_sample(self, x0, t, noise=None):
        """Forward process: sample x_t ~ q(x_t | x_0).

        Args:
            x0: clean frames [B, C, H, W].
            t: per-sample timestep indices [B], dtype long.
            noise: optional pre-drawn Gaussian noise (same shape as x0);
                drawn fresh if None.

        Returns:
            Tuple (x_t, noise) with
            x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        # Broadcast per-sample scalars over (C, H, W).
        a = self.sqrt_alpha_bar[t][:, None, None, None]
        b = self.sqrt_one_minus_alpha_bar[t][:, None, None, None]
        return a * x0 + b * noise, noise

    def training_loss(self, x_target, x_cond):
        """Compute the simple DDPM training loss (predict noise).

        Args:
            x_target: clean target frames [B, C, H, W].
            x_cond: condition frames [B, C, H, W].

        Returns:
            Scalar MSE between predicted and true noise.
        """
        B = x_target.shape[0]
        # One uniformly random timestep per sample.
        t = torch.randint(0, self.T, (B,), device=x_target.device)
        noise = torch.randn_like(x_target)
        x_noisy, _ = self.q_sample(x_target, t, noise)
        eps_pred = self.model(x_noisy, t, cond=x_cond)
        return F.mse_loss(eps_pred, noise)

    @torch.no_grad()
    def sample(self, x_cond, shape=None):
        """Generate target frames by full iterative denoising (DDPM).

        Args:
            x_cond: condition frames [B, C_cond, H, W].
            shape: (B, C_out, H, W) of the target. If None, the condition's
                shape is reused, which assumes C_out == C_cond — pass `shape`
                explicitly when channel counts differ.

        Returns:
            Denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape  # assume same channels as the condition
        x = torch.randn(shape, device=device)
        for i in reversed(range(self.T)):
            t = torch.full((shape[0],), i, device=device, dtype=torch.long)
            eps = self.model(x, t, cond=x_cond)
            alpha_bar = self.alpha_bar[i]
            beta = self.betas[i]
            # Posterior mean in eps-parameterization; uses the precomputed
            # sqrt(1/alpha) buffer instead of recomputing it each step.
            mean = self.sqrt_recip_alpha[i] * (x - beta / (1 - alpha_bar).sqrt() * eps)
            if i > 0:
                sigma = self.posterior_variance[i].sqrt()
                x = mean + sigma * torch.randn_like(x)
            else:
                # Last step: posterior_variance[0] == 0, no noise added.
                x = mean
        return x

    @torch.no_grad()
    def sample_ddim(self, x_cond, shape=None, steps=50, eta=0.0, x0_clip=5.0):
        """DDIM accelerated sampling.

        Args:
            x_cond: condition [B, C_cond, H, W].
            shape: target shape; inferred from x_cond if None (assumes equal
                channel counts, as in `sample`).
            steps: number of DDIM steps (<< T for speed).
            eta: stochasticity (0 = deterministic DDIM, 1 = DDPM-like).
            x0_clip: symmetric clamp applied to the predicted x0 for
                numerical stability; set to None to disable.

        Returns:
            Denoised sample [B, C_out, H, W].
        """
        device = x_cond.device
        if shape is None:
            shape = x_cond.shape
        # Sub-sample timesteps uniformly over [0, T-1]; walked in reverse,
        # consecutive pairs (t_cur, t_next) define each DDIM jump.
        step_indices = torch.linspace(0, self.T - 1, steps + 1, dtype=torch.long, device=device)
        step_indices = step_indices.flip(0)  # reverse: T-1 ... 0
        x = torch.randn(shape, device=device)
        for idx in range(len(step_indices) - 1):
            t_cur = step_indices[idx]
            t_next = step_indices[idx + 1]
            t_batch = t_cur.expand(shape[0])
            eps = self.model(x, t_batch, cond=x_cond)
            ab_cur = self.alpha_bar[t_cur]
            ab_next = self.alpha_bar[t_next]
            # Predict x0 from the current noisy sample and predicted noise.
            x0_pred = (x - (1 - ab_cur).sqrt() * eps) / ab_cur.sqrt()
            if x0_clip is not None:
                x0_pred = x0_pred.clamp(-x0_clip, x0_clip)  # stability clamp
            # sigma interpolates deterministic DDIM (eta=0) and DDPM (eta=1).
            sigma = eta * ((1 - ab_next) / (1 - ab_cur) * (1 - ab_cur / ab_next)).sqrt()
            # Direction pointing towards x at the next (earlier) timestep.
            dir_xt = (1 - ab_next - sigma**2).sqrt() * eps
            x = ab_next.sqrt() * x0_pred + dir_xt
            if sigma > 0:
                x = x + sigma * torch.randn_like(x)
        return x