""" Diffusion Process Implementation (DDPM + DDIM) Changes from original: - Fixed q_posterior_mean_variance return value count (was 3, caller expected 4) - Fixed DDIM sigma/dir_xt formula inconsistency - Made GaussianDiffusion an nn.Module with registered buffers for proper device handling """ import torch import torch.nn as nn import torch.nn.functional as F import numpy as np class GaussianDiffusion(nn.Module): def __init__(self, timesteps=1500, beta_start=1e-4, beta_end=0.02, schedule_type="linear"): super().__init__() self.timesteps = timesteps if schedule_type == "linear": betas = torch.linspace(beta_start, beta_end, timesteps) elif schedule_type == "cosine": betas = self._cosine_beta_schedule(timesteps) else: raise ValueError(f"Unknown schedule: {schedule_type}") alphas = 1.0 - betas alphas_cumprod = torch.cumprod(alphas, dim=0) alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) # Register all schedule tensors as buffers so they move with .to(device) self.register_buffer('betas', betas) self.register_buffer('alphas', alphas) self.register_buffer('alphas_cumprod', alphas_cumprod) self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - alphas_cumprod)) posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) self.register_buffer('posterior_variance', posterior_variance) self.register_buffer('posterior_log_variance_clipped', torch.log(torch.clamp(posterior_variance, min=1e-20))) self.register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)) # Precompute reciprocals used in _predict_xstart_from_noise (avoids recomputation per step) self.register_buffer('recip_sqrt_alphas_cumprod', 1.0 / torch.sqrt(alphas_cumprod)) self.register_buffer('sqrt_recip_minus_one', torch.sqrt(1.0 / alphas_cumprod - 1.0)) def _cosine_beta_schedule(self, timesteps, s=0.008): steps = timesteps + 1 x = torch.linspace(0, timesteps, steps) alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 alphas_cumprod = alphas_cumprod / alphas_cumprod[0] betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) return torch.clip(betas, 0.0001, 0.9999) def q_sample(self, x_start, t, noise=None): if noise is None: noise = torch.randn_like(x_start) sqrt_alphas_cumprod_t = self._extract(self.sqrt_alphas_cumprod, t, x_start.shape) sqrt_one_minus_alphas_cumprod_t = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise def p_mean_variance(self, model, x_t, t, labels, clip_denoised=True): pred_noise = model(x_t, t, labels) x_start = self._predict_xstart_from_noise(x_t, t, pred_noise) if clip_denoised: x_start = torch.clamp(x_start, -1.0, 1.0) # FIX: q_posterior_mean_variance returns 3 values, not 4 model_mean, posterior_variance, posterior_log_variance = self.q_posterior_mean_variance(x_start, x_t, t) return model_mean, posterior_variance, posterior_log_variance, x_start def _predict_xstart_from_noise(self, x_t, t, noise): return ( self._extract(self.recip_sqrt_alphas_cumprod, t, x_t.shape) * x_t - self._extract(self.sqrt_recip_minus_one, t, x_t.shape) * noise ) def q_posterior_mean_variance(self, x_start, x_t, t): posterior_mean = ( self._extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + self._extract(self.posterior_mean_coef2, t, x_t.shape) * x_t ) posterior_variance = self._extract(self.posterior_variance, t, x_t.shape) posterior_log_variance = self._extract(self.posterior_log_variance_clipped, t, x_t.shape) return posterior_mean, posterior_variance, posterior_log_variance def p_sample(self, model, x_t, t, labels): model_mean, _, model_log_variance, _ = self.p_mean_variance(model, x_t, t, labels) noise = torch.randn_like(x_t) nonzero_mask = ((t != 0).float().view(-1, *([1] * (len(x_t.shape) - 1)))) return model_mean + nonzero_mask * torch.exp(0.5 * model_log_variance) * noise def ddim_sample_step(self, model, x_t, t, t_next, labels, eta=0.0): pred_noise = model(x_t, t, labels) alpha_t = self._extract(self.alphas_cumprod, t, x_t.shape) alpha_t_next = self._extract(self.alphas_cumprod, t_next, x_t.shape) if t_next[0] >= 0 else torch.ones_like(alpha_t) x0_pred = (x_t - torch.sqrt(1 - alpha_t) * pred_noise) / torch.sqrt(alpha_t) x0_pred = torch.clamp(x0_pred, -1.0, 1.0) # FIX: Consistent DDIM sigma computation # sigma^2 = eta^2 * (1 - alpha_{t-1}) / (1 - alpha_t) * (1 - alpha_t / alpha_{t-1}) sigma_sq = eta**2 * (1 - alpha_t_next) / (1 - alpha_t) * (1 - alpha_t / alpha_t_next) if eta > 0 else 0 sigma_t = torch.sqrt(torch.clamp(sigma_sq, min=0)) if eta > 0 else 0 # dir_xt uses the same sigma^2 dir_xt = torch.sqrt(torch.clamp(1 - alpha_t_next - sigma_sq, min=0)) * pred_noise noise = torch.randn_like(x_t) if eta > 0 else 0 return torch.sqrt(alpha_t_next) * x0_pred + dir_xt + sigma_t * noise def sample(self, model, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0): batch_size = labels.shape[0] img = torch.randn((batch_size, channels, height, width), device=device) if use_ddim: skip = self.timesteps // ddim_steps seq = list(range(0, self.timesteps, skip)) seq_next = [-1] + seq[:-1] seq_iter = reversed(list(zip(seq, seq_next))) if progress: from tqdm import tqdm seq_iter = tqdm(seq_iter, desc=f'DDIM Sampling ({ddim_steps} steps)', total=len(seq)) for i, j in seq_iter: t = torch.full((batch_size,), i, device=device, dtype=torch.long) t_next = torch.full((batch_size,), j, device=device, dtype=torch.long) img = self.ddim_sample_step(model, img, t, t_next, labels, eta) else: if progress: from tqdm import tqdm timesteps_iter = tqdm(reversed(range(self.timesteps)), total=self.timesteps) else: timesteps_iter = reversed(range(self.timesteps)) for i in timesteps_iter: t = torch.full((batch_size,), i, device=device, dtype=torch.long) img = self.p_sample(model, img, t, labels) return img def training_losses(self, model, x_start, labels, t, noise=None): if noise is None: noise = torch.randn_like(x_start) x_t = self.q_sample(x_start, t, noise) pred_noise = model(x_t, t, labels) return F.mse_loss(pred_noise, noise, reduction='none').mean(dim=list(range(1, len(pred_noise.shape)))) def _extract(self, a, t, x_shape): batch_size = t.shape[0] out = a.gather(0, t) return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))) class ConditionalDiffusionModel(nn.Module): def __init__(self, unet, diffusion_process): super().__init__() self.unet = unet self.diffusion = diffusion_process def forward(self, x, t, labels): return self.unet(x, t, labels) def get_loss(self, x, labels, noise=None): batch_size = x.shape[0] device = x.device t = torch.randint(0, self.diffusion.timesteps, (batch_size,), device=device).long() return self.diffusion.training_losses(self, x, labels, t, noise=noise).mean() def sample(self, labels, channels, height, width, device, progress=False, use_ddim=True, ddim_steps=50, eta=0.0): self.eval() with torch.no_grad(): return self.diffusion.sample(self, labels, channels, height, width, device, progress, use_ddim, ddim_steps, eta)