File size: 8,506 Bytes
c46900a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 | """
Diffusion Process Implementation (DDPM + DDIM)
Changes from original:
- Fixed q_posterior_mean_variance return value count (was 3, caller expected 4)
- Fixed DDIM sigma/dir_xt formula inconsistency
- Made GaussianDiffusion an nn.Module with registered buffers for proper device handling
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
class GaussianDiffusion(nn.Module):
def __init__(self, timesteps=1500, beta_start=1e-4, beta_end=0.02, schedule_type="linear"):
super().__init__()
self.timesteps = timesteps
if schedule_type == "linear":
betas = torch.linspace(beta_start, beta_end, timesteps)
elif schedule_type == "cosine":
betas = self._cosine_beta_schedule(timesteps)
else:
raise ValueError(f"Unknown schedule: {schedule_type}")
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
# Register all schedule tensors as buffers so they move with .to(device)
self.register_buffer('betas', betas)
self.register_buffer('alphas', alphas)
self.register_buffer('alphas_cumprod', alphas_cumprod)
self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod))
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
self.register_buffer('posterior_variance', posterior_variance)
self.register_buffer('posterior_log_variance_clipped', torch.log(torch.clamp(posterior_variance, min=1e-20)))
self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod))
# Precompute reciprocals used in _predict_xstart_from_noise (avoids recomputation per step)
self.register_buffer('recip_sqrt_alphas_cumprod', 1.0 / torch.sqrt(alphas_cumprod))
self.register_buffer('sqrt_recip_minus_one', torch.sqrt(1.0 / alphas_cumprod - 1.0))
def _cosine_beta_schedule(self, timesteps, s=0.008):
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0.0001, 0.9999)
def q_sample(self, x_start, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape)
sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
def p_mean_variance(self, model, x_t, t, labels, clip_denoised=True):
pred_noise = model(x_t, t, labels)
x_start = self._predict_xstart_from_noise(x_t, t, pred_noise)
if clip_denoised:
x_start = torch.clamp(x_start, -1.0, 1.0)
# FIX: q_posterior_mean_variance returns 3 values, not 4
model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(x_start, x_t, t)
return model_mean, posterior_variance, posterior_log_variance, x_start
def _predict_xstart_from_noise(self, x_t, t, noise):
return (
self._extract(self.recip_sqrt_alphas_cumprod, t, x_t.shape) * x_t -
self._extract(self.sqrt_recip_minus_one, t, x_t.shape) * noise
)
def q_posterior_mean_variance(self, x_start, x_t, t):
posterior_mean = (
self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = self._extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance = self._extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance
def p_sample(self, model, x_t, t, labels):
model_mean, _, model_log_variance, _ = self.p_mean_variance(model, x_t, t, labels)
noise = torch.randn_like(x_t)
nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1))))
return model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise
def ddim_sample_step(self, model, x_t, t, t_next, labels, eta=0.0):
pred_noise = model(x_t, t, labels)
alpha_t = self._extract(self.alphas_cumprod, t, x_t.shape)
alpha_t_next = self._extract(self.alphas_cumprod, t_next, x_t.shape) if t_next[0] >= 0 else torch.ones_like(alpha_t)
x0_pred = (x_t - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t)
x0_pred = torch.clamp(x0_pred, -1.0, 1.0)
# FIX: Consistent DDIM sigma computation
# sigma^2 = eta^2 * (1 - alpha_{t-1}) / (1 - alpha_t) * (1 - alpha_t / alpha_{t-1})
sigma_sq = eta**2 * (1 - alpha_t_next) / (1 - alpha_t) * (1 - alpha_t / alpha_t_next) if eta > 0 else 0
sigma_t = torch.sqrt(torch.clamp(sigma_sq, min=0)) if eta > 0 else 0
# dir_xt uses the same sigma^2
dir_xt = torch.sqrt(torch.clamp(1 - alpha_t_next - sigma_sq, min=0)) * pred_noise
noise = torch.randn_like(x_t) if eta > 0 else 0
return torch.sqrt(alpha_t_next) * x0_pred + dir_xt + sigma_t * noise
def sample(self, model, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0):
batch_size = labels.shape[0]
img = torch.randn((batch_size, channels, height, width), device=device)
if use_ddim:
skip = self.timesteps // ddim_steps
seq = list(range(0, self.timesteps, skip))
seq_next = [-1] + seq[:-1]
seq_iter = reversed(list(zip(seq, seq_next)))
if progress:
from tqdm import tqdm
seq_iter = tqdm(seq_iter, desc=f'DDIM Sampling ({ddim_steps} steps)', total=len(seq))
for i, j in seq_iter:
t = torch.full((batch_size,), i, device=device, dtype=torch.long)
t_next = torch.full((batch_size,), j, device=device, dtype=torch.long)
img = self.ddim_sample_step(model, img, t, t_next, labels, eta)
else:
if progress:
from tqdm import tqdm
timesteps_iter = tqdm(reversed(range(self.timesteps)), total=self.timesteps)
else:
timesteps_iter = reversed(range(self.timesteps))
for i in timesteps_iter:
t = torch.full((batch_size,), i, device=device, dtype=torch.long)
img = self.p_sample(model, img, t, labels)
return img
def training_losses(self, model, x_start, labels, t, noise=None):
if noise is None:
noise = torch.randn_like(x_start)
x_t = self.q_sample(x_start, t, noise)
pred_noise = model(x_t, t, labels)
return F.mse_loss(pred_noise, noise, reduction='none').mean(dim=list(range(1, len(pred_noise.shape))))
def _extract(self, a, t, x_shape):
batch_size = t.shape[0]
out = a.gather(0, t)
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
class ConditionalDiffusionModel(nn.Module):
def __init__(self, unet, diffusion_process):
super().__init__()
self.unet = unet
self.diffusion = diffusion_process
def forward(self, x, t, labels):
return self.unet(x, t, labels)
def get_loss(self, x, labels, noise=None):
batch_size = x.shape[0]
device = x.device
t = torch.randint(0, self.diffusion.timesteps, (batch_size,), device=device).long()
return self.diffusion.training_losses(self, x, labels, t, noise=noise).mean()
def sample(self, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0):
self.eval()
with torch.no_grad():
return self.diffusion.sample(self, labels, channels, height, width, device, progress, use_ddim, ddim_steps, eta)
|