# DDPM-6param / src/diffusion_conditional.py
# Uploaded by collins909: 6-parameter conditional DDPM (HI emulation, CAMELS LH params_6, best checkpoint)
# Commit c46900a (verified)
"""
Diffusion Process Implementation (DDPM + DDIM)
Changes from original:
- Fixed p_mean_variance unpacking of q_posterior_mean_variance (it returns 3 values; the caller previously unpacked 4)
- Fixed DDIM sigma/dir_xt formula inconsistency
- Made GaussianDiffusion an nn.Module with registered buffers for proper device handling
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GaussianDiffusion(nn.Module):
    """Gaussian diffusion process with DDPM (ancestral) and DDIM samplers.

    The beta schedule and every quantity derived from it are registered as
    buffers, so they follow the module through ``.to(device)`` and are stored
    in its ``state_dict``.

    Every ``model`` argument below is called as ``model(x_t, t, labels)`` and
    is trained (see ``training_losses``) to predict the noise that was mixed
    into ``x_t``.
    """

    def __init__(self, timesteps=1500, beta_start=1e-4, beta_end=0.02, schedule_type="linear"):
        """Build the noise schedule and its derived coefficients.

        Args:
            timesteps: Number of diffusion steps T.
            beta_start: First beta of the linear schedule (unused for "cosine").
            beta_end: Last beta of the linear schedule (unused for "cosine").
            schedule_type: Either "linear" or "cosine".

        Raises:
            ValueError: If ``schedule_type`` is not recognised.
        """
        super().__init__()
        self.timesteps = timesteps
        if schedule_type == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps)
        elif schedule_type == "cosine":
            betas = self._cosine_beta_schedule(timesteps)
        else:
            raise ValueError(f"Unknown schedule: {schedule_type}")
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}; alpha_bar before the first step is defined as 1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # Register all schedule tensors as buffers so they move with .to(device).
        # The buffer names are part of the state_dict (and saved checkpoints): keep them stable.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', alphas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))

        # Variance of q(x_{t-1} | x_t, x_0). It is exactly 0 at t=0, hence the clamp before log.
        posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        self.register_buffer('posterior_log_variance_clipped', torch.log(torch.clamp(posterior_variance, min=1e-20)))
        # Mean of q(x_{t-1} | x_t, x_0) = coef1 * x_0 + coef2 * x_t.
        self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod))

        # Precomputed factors for _predict_xstart_from_noise (avoids recomputation per step).
        self.register_buffer('recip_sqrt_alphas_cumprod', 1.0 / torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_recip_minus_one', torch.sqrt(1.0 / alphas_cumprod - 1.0))

    def _cosine_beta_schedule(self, timesteps, s=0.008):
        """Return ``timesteps`` betas from the cosine alpha_bar schedule, clipped to [1e-4, 0.9999].

        ``s`` is the small offset added to t/T inside the cosine.
        """
        steps = timesteps + 1
        x = torch.linspace(0, timesteps, steps)
        alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0) = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        Args:
            x_start: Clean batch x_0 (batch dimension first).
            t: Long tensor of per-sample timesteps, shape (batch,).
            noise: Optional noise with the shape of ``x_start``; drawn from N(0, I) if None.
        """
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise

    def p_mean_variance(self, model, x_t, t, labels, clip_denoised=True):
        """Compute the reverse-step distribution p(x_{t-1} | x_t) from the model's noise prediction.

        Returns:
            ``(mean, variance, log_variance, x_start)``. ``x_start`` is the predicted x_0,
            clamped to [-1, 1] when ``clip_denoised`` is True.
        """
        pred_noise = model(x_t, t, labels)
        x_start = self._predict_xstart_from_noise(x_t, t, pred_noise)
        if clip_denoised:
            x_start = torch.clamp(x_start, -1.0, 1.0)
        # q_posterior_mean_variance returns 3 values; x_start is appended as the 4th.
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(x_start, x_t, t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    def _predict_xstart_from_noise(self, x_t, t, noise):
        """Invert q_sample: x_0 = x_t / sqrt(alpha_bar_t) - sqrt(1/alpha_bar_t - 1) * noise."""
        return (
            self._extract(self.recip_sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            self._extract(self.sqrt_recip_minus_one, t, x_t.shape) * noise
        )

    def q_posterior_mean_variance(self, x_start, x_t, t):
        """Return ``(mean, variance, log_variance)`` of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance

    def p_sample(self, model, x_t, t, labels):
        """Take one ancestral DDPM step x_t -> x_{t-1}; no noise is added where t == 0."""
        model_mean, _, model_log_variance, _ = self.p_mean_variance(model, x_t, t, labels)
        noise = torch.randn_like(x_t)
        nonzero_mask = (t != 0).float().view(-1, *([1] * (x_t.dim() - 1)))
        return model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise

    def ddim_sample_step(self, model, x_t, t, t_next, labels, eta=0.0):
        """Take one DDIM step from timestep ``t`` to ``t_next``.

        Args:
            model: Noise-prediction network.
            x_t: Current batch.
            t: Long tensor (batch,) of current timesteps.
            t_next: Long tensor (batch,) of target timesteps; ``-1`` means "step to x_0".
            labels: Conditioning passed through to ``model``.
            eta: 0 gives deterministic DDIM; 1 gives DDPM-like stochasticity.
        """
        pred_noise = model(x_t, t, labels)
        alpha_t = self._extract(self.alphas_cumprod, t, x_t.shape)
        # alpha_bar for t_next == -1 is defined as 1. Decide per sample rather than by
        # t_next[0] only, so mixed batches are correct and no host-device sync is needed.
        has_next = (t_next >= 0).view(-1, *([1] * (x_t.dim() - 1)))
        alpha_next = torch.where(
            has_next,
            self._extract(self.alphas_cumprod, t_next.clamp(min=0), x_t.shape),
            torch.ones_like(alpha_t),
        )

        x0_pred = torch.clamp(self._predict_xstart_from_noise(x_t, t, pred_noise), -1.0, 1.0)

        # sigma^2 = eta^2 * (1 - alpha_{t-1}) / (1 - alpha_t) * (1 - alpha_t / alpha_{t-1});
        # the same sigma^2 is used in the direction term so the two stay consistent.
        if eta > 0:
            sigma_sq = eta ** 2 * (1 - alpha_next) / (1 - alpha_t) * (1 - alpha_t / alpha_next)
            sigma_sq = torch.clamp(sigma_sq, min=0)
            sigma_t = torch.sqrt(sigma_sq)
        else:
            sigma_sq, sigma_t = 0, 0
        dir_xt = torch.sqrt(torch.clamp(1 - alpha_next - sigma_sq, min=0)) * pred_noise
        noise = torch.randn_like(x_t) if eta > 0 else 0
        return torch.sqrt(alpha_next) * x0_pred + dir_xt + sigma_t * noise

    def sample(self, model, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0):
        """Generate a batch of shape (labels.shape[0], channels, height, width) starting from N(0, I).

        With ``use_ddim`` the timesteps are visited with a uniform stride of
        ``timesteps // ddim_steps`` (at least 1). Otherwise all ``timesteps``
        ancestral DDPM steps are run.

        Raises:
            ValueError: If ``use_ddim`` is set and ``ddim_steps < 1``.
        """
        if use_ddim and ddim_steps < 1:
            raise ValueError(f"ddim_steps must be >= 1, got {ddim_steps}")
        batch_size = labels.shape[0]
        img = torch.randn((batch_size, channels, height, width), device=device)
        if use_ddim:
            # Clamp the stride so ddim_steps > timesteps falls back to every step
            # instead of calling range() with a zero step.
            skip = max(1, self.timesteps // ddim_steps)
            seq = list(range(0, self.timesteps, skip))
            seq_next = [-1] + seq[:-1]
            seq_iter = reversed(list(zip(seq, seq_next)))
            if progress:
                from tqdm import tqdm
                seq_iter = tqdm(seq_iter, desc=f'DDIM Sampling ({ddim_steps} steps)', total=len(seq))
            for i, j in seq_iter:
                t = torch.full((batch_size,), i, device=device, dtype=torch.long)
                t_next = torch.full((batch_size,), j, device=device, dtype=torch.long)
                img = self.ddim_sample_step(model, img, t, t_next, labels, eta)
        else:
            timesteps_iter = reversed(range(self.timesteps))
            if progress:
                from tqdm import tqdm
                timesteps_iter = tqdm(timesteps_iter, total=self.timesteps)
            for i in timesteps_iter:
                t = torch.full((batch_size,), i, device=device, dtype=torch.long)
                img = self.p_sample(model, img, t, labels)
        return img

    def training_losses(self, model, x_start, labels, t, noise=None):
        """Return the per-sample MSE between predicted and true noise, shape (batch,)."""
        if noise is None:
            noise = torch.randn_like(x_start)
        x_t = self.q_sample(x_start, t, noise)
        pred_noise = model(x_t, t, labels)
        return F.mse_loss(pred_noise, noise, reduction='none').mean(dim=list(range(1, pred_noise.dim())))

    def _extract(self, a, t, x_shape):
        """Gather ``a[t]`` and reshape to (batch, 1, ..., 1) so it broadcasts against ``x_shape``."""
        batch_size = t.shape[0]
        out = a.gather(0, t)
        return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
class ConditionalDiffusionModel(nn.Module):
    """Bundle a conditional network with the diffusion process that trains and samples it.

    ``forward`` delegates to the wrapped network, so this module itself is what
    gets passed as ``model`` to the diffusion's ``training_losses`` and ``sample``.
    """

    def __init__(self, unet, diffusion_process):
        """
        Args:
            unet: Network called as ``unet(x, t, labels)``.
            diffusion_process: Object providing ``timesteps``, ``training_losses`` and
                ``sample``. It is assigned as an attribute, so if it is an ``nn.Module``
                its buffers follow this module's ``.to(device)``.
        """
        super().__init__()
        self.unet = unet
        self.diffusion = diffusion_process

    def forward(self, x, t, labels):
        """Delegate to the wrapped network: ``unet(x, t, labels)``."""
        return self.unet(x, t, labels)

    def get_loss(self, x, labels, noise=None):
        """Return the batch-mean diffusion loss at timesteps drawn uniformly from [0, timesteps)."""
        batch_size = x.shape[0]
        t = torch.randint(0, self.diffusion.timesteps, (batch_size,), device=x.device).long()
        return self.diffusion.training_losses(self, x, labels, t, noise=noise).mean()

    def sample(self, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0):
        """Generate samples in eval mode without gradients, then restore the previous train/eval modes.

        Arguments are forwarded unchanged to ``self.diffusion.sample``.
        """
        # Record every submodule's mode so sampling mid-training does not leave
        # dropout/normalisation layers stuck in eval mode afterwards.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                return self.diffusion.sample(self, labels, channels, height, width, device, progress, use_ddim, ddim_steps, eta)
        finally:
            for module, mode in previous_modes:
                module.training = mode