| |
| |
| |
| |
|
|
| |
| |
|
|
|
|
| import math |
| import torch |
|
|
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
def extract(a, t, x_shape):
    """Look up per-sample schedule values and make them broadcastable.

    Args:
        a: 1-D tensor of per-timestep constants.
        t: 1-D integer index tensor with one timestep per batch element.
        x_shape: shape of the tensor the result will be multiplied with;
            only its rank is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dims,
        holding ``a[t[i]]`` for each batch element ``i``.
    """
    gathered = torch.gather(a, -1, t)
    # One leading batch axis, then singleton axes for every remaining dim.
    trailing_ones = (1,) * (len(x_shape) - 1)
    return gathered.reshape(t.shape[0], *trailing_ones)
|
|
|
|
def cosine_beta_schedule(timesteps, s=0.008):
    """Build a cosine noise schedule and return its per-step betas.

    The cumulative signal level follows
    ``cos(((t / T) + s) / (1 + s) * pi / 2) ** 2`` normalised so that it
    equals 1 at ``t = 0``; betas are derived from consecutive ratios.

    Args:
        timesteps: number of diffusion steps ``T``; must be at least 1.
        s: small offset applied inside the cosine.

    Returns:
        float32 tensor of shape ``(timesteps,)`` with values clipped to
        ``[0, 0.999]``.

    Raises:
        ValueError: if ``timesteps`` is smaller than 1 (the schedule would
            otherwise divide by zero and come back empty or NaN).
    """
    if timesteps < 1:
        raise ValueError(f"timesteps must be >= 1, got {timesteps}")
    steps = timesteps + 1
    t = torch.linspace(0, timesteps, steps, dtype=torch.float32) / timesteps
    alphas_cumprod = torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1.0 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    # The last ratio approaches zero, so cap beta below 1 to keep
    # 1 - beta strictly positive.
    return torch.clip(betas, 0, 0.999)
|
|
|
|
class GaussianDiffusion(nn.Module):
    """
    Core diffusion module that wraps a denoiser (UNet):
    - Precomputes diffusion constants (betas, alphas, etc.)
    - Provides training loss (forward): randomly pick t, add noise, regress target
    - Provides sampling loops (DDPM or DDIM)

    The denoiser must have forward(x, t, [x_self_cond]), returning a predicted target
    (epsilon, x0, or v depending on `objective`).
    """

    # Training targets the module knows how to build and invert.
    _OBJECTIVES = ('pred_noise', 'pred_x0', 'pred_v')

    def __init__(self, model, *, image_size, timesteps=400, beta_schedule='cosine',
                 objective='pred_noise', sampling_steps=None, eta=0.0,
                 self_condition=False, auto_normalize=True, clamp_x0=True):
        """
        Args:
            model (nn.Module): denoiser network (e.g., UNet); must expose `.channels`
                and have at least one parameter (its dtype/device are used for buffers).
            image_size (int or (h,w)): training/sampling resolution (must match UNet).
            timesteps (int): T. Smaller (e.g., 400) is enough for MNIST.
            beta_schedule (str): only 'cosine' implemented here for simplicity.
            objective (str): 'pred_noise'|'pred_x0'|'pred_v' (training target).
            sampling_steps (int or None): if set < T => DDIM sampling with S steps; else DDPM full T.
            eta (float): DDIM stochasticity (0.0 => deterministic).
            self_condition (bool): optional self-conditioning flag.
            auto_normalize (bool): map inputs [0,1] <-> [-1,1] inside module.
            clamp_x0 (bool): clamp predicted x0 to [-1,1] during sampling for stability.

        Raises:
            ValueError: unknown `objective`, or `sampling_steps` below 1.
            NotImplementedError: `beta_schedule` other than 'cosine'.
        """
        super().__init__()

        # Reject typos up front: an unknown objective would otherwise fall
        # through every `else` branch and silently behave like 'pred_v'.
        if objective not in self._OBJECTIVES:
            raise ValueError(
                f"objective must be one of {self._OBJECTIVES}, got {objective!r}")
        if beta_schedule != 'cosine':
            raise NotImplementedError(
                "For MNIST small, keep beta_schedule='cosine'")

        self.model = model
        param = next(model.parameters())
        param_dtype = param.dtype
        param_device = param.device
        self.channels = model.channels
        self.self_condition = self_condition
        self.objective = objective
        self.clamp_x0 = clamp_x0
        # Drives normalize()/unnormalize(); kept as a plain flag (instead of
        # lambdas stored on the instance) so the module stays picklable.
        self.auto_normalize = bool(auto_normalize)

        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size

        # ---- noise schedule -------------------------------------------------
        betas = cosine_beta_schedule(timesteps).to(
            device=param_device, dtype=param_dtype)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} := 1.
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        # ---- sampling configuration -----------------------------------------
        self.num_timesteps = int(betas.shape[0])
        # A falsy value (None or 0) means "use all T steps".
        self.sampling_steps = int(
            sampling_steps) if sampling_steps else self.num_timesteps
        if self.sampling_steps < 1:
            raise ValueError(
                f"sampling_steps must be >= 1, got {sampling_steps}")
        self.is_ddim_sampling = self.sampling_steps < self.num_timesteps
        self.ddim_sampling_eta = float(eta)

        # ---- constants for q(x_t | x_0) and its inversions ------------------
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1.0 - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1.0 / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1.0 / alphas_cumprod - 1.0))

        # ---- constants for the posterior q(x_{t-1} | x_t, x_0) --------------
        posterior_variance = betas * \
            (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # The variance is 0 at t=0, so clamp before taking the log.
        self.register_buffer('posterior_log_variance_clipped', torch.log(
            posterior_variance.clamp(min=1e-20)))
        self.register_buffer('posterior_mean_coef1', betas *
                             torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev)
                             * torch.sqrt(1.0 - betas) / (1.0 - alphas_cumprod))

        # ---- per-timestep loss weights --------------------------------------
        snr = alphas_cumprod / (1 - alphas_cumprod)
        if objective == 'pred_noise':
            # Uniform weighting. ones_like avoids the NaN that snr / snr
            # produces whenever snr is 0 or inf.
            loss_weight = torch.ones_like(snr)
        elif objective == 'pred_x0':
            loss_weight = snr
        else:
            loss_weight = snr / (snr + 1)
        self.register_buffer('loss_weight', loss_weight)

    @property
    def device(self):
        """Convenience: returns the device where buffers live."""
        return self.betas.device

    def normalize(self, x):
        """Map [0,1] -> [-1,1] when `auto_normalize` is on; identity otherwise."""
        return x * 2 - 1 if self.auto_normalize else x

    def unnormalize(self, x):
        """Map [-1,1] -> [0,1] when `auto_normalize` is on; identity otherwise."""
        return (x + 1) * 0.5 if self.auto_normalize else x

    # ------------------------------------------------------------------ #
    # Forward process
    # ------------------------------------------------------------------ #
    def q_sample(self, x0, t, noise=None):
        """
        Sample x_t from q(x_t | x_0):
            x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise

        Args:
            x0 (Tensor): clean input.
            t (LongTensor): per-sample time indices.
            noise (Tensor|None): noise to mix in; drawn with randn_like(x0) if None.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        return extract(self.sqrt_alphas_cumprod, t, x0.shape) * x0 + \
            extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise

    # ------------------------------------------------------------------ #
    # Conversions between parameterizations
    # ------------------------------------------------------------------ #
    def predict_start_from_noise(self, x_t, t, eps):
        """Given epsilon prediction, reconstruct x0."""
        return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps

    def predict_noise_from_start(self, x_t, t, x0):
        """Given x0 prediction, reconstruct epsilon."""
        return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_v(self, x0, t, eps):
        """v-parameterization = sqrt(alpha_bar)*eps - sqrt(1-alpha_bar)*x0."""
        # Use the precomputed buffers rather than taking the square root of
        # the whole schedule again on every call.
        return extract(self.sqrt_alphas_cumprod, t, x0.shape) * eps - \
            extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0

    def predict_start_from_v(self, x_t, t, v):
        """Given v prediction, reconstruct x0."""
        return extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - \
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v

    # ------------------------------------------------------------------ #
    # Denoiser wrapper
    # ------------------------------------------------------------------ #
    def model_predictions(self, x, t, x_self_cond=None, clip_x_start=False, rederive_pred_noise=False):
        """
        Run the denoiser and return (pred_noise, x0):
        - If objective == pred_noise: UNet predicts epsilon directly.
        - If objective == pred_x0:   UNet predicts x0 directly.
        - If objective == pred_v:    UNet predicts v; we convert to x0 & epsilon.

        Args:
            x (Tensor): noised image x_t.
            t (LongTensor): time indices.
            x_self_cond (Tensor|None): optional self-conditioning input; the
                denoiser is called with two arguments when this is None.
            clip_x_start (bool): clamp x0 to [-1,1] after prediction.
            rederive_pred_noise (bool): if True (and clipping is on), recompute
                epsilon from the clamped x0. Only affects the pred_noise
                objective; the other objectives always derive epsilon from x0.

        Returns:
            (pred_noise, x0) both shape like x.
        """
        if x_self_cond is not None:
            out = self.model(x, t, x_self_cond)
        else:
            out = self.model(x, t)

        def maybe_clip(z):
            return z.clamp(-1, 1) if clip_x_start else z

        if self.objective == 'pred_noise':
            pred_noise = out
            x0 = maybe_clip(self.predict_start_from_noise(x, t, pred_noise))
            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x0)
        elif self.objective == 'pred_x0':
            x0 = maybe_clip(out)
            pred_noise = self.predict_noise_from_start(x, t, x0)
        else:
            x0 = maybe_clip(self.predict_start_from_v(x, t, out))
            pred_noise = self.predict_noise_from_start(x, t, x0)

        return pred_noise, x0

    def q_posterior(self, x0, x_t, t):
        """
        Compute the Gaussian q(x_{t-1} | x_t, x0) parameters:
            mean = c1 * x0 + c2 * x_t
            var, log_var: closed-form from betas and alpha_bars.

        Returns:
            (mean, var, log_var); var and log_var are broadcastable to x_t.
        """
        mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x0 + \
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        var = extract(self.posterior_variance, t, x_t.shape)
        log_var = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #
    def p_losses(self, x_start, t, noise=None):
        """
        DDPM training objective:
        - Sample x_t = q(x_t | x_0)
        - Optionally (self_condition, with probability 0.5) compute a
          gradient-free x0 estimate and feed it back as conditioning
        - Predict target according to objective and MSE it, weighted per timestep

        Args:
            x_start (Tensor): clean, already-normalized input.
            t (LongTensor): per-sample time indices.
            noise (Tensor|None): noise to use; drawn with randn_like if None.

        Returns:
            Scalar loss tensor.
        """
        noise = torch.randn_like(x_start) if noise is None else noise
        x = self.q_sample(x_start, t, noise)

        x_self_cond = None
        if self.self_condition and torch.rand(1, device=self.device) < 0.5:
            # The conditioning estimate must not receive gradients.
            with torch.no_grad():
                _, x_self_cond = self.model_predictions(
                    x, t, None, clip_x_start=True)

        if x_self_cond is not None:
            model_out = self.model(x, t, x_self_cond)
        else:
            model_out = self.model(x, t)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        else:
            target = self.predict_v(x_start, t, noise)

        # Per-sample MSE (mean over every non-batch dim), then timestep weight.
        loss = F.mse_loss(model_out, target, reduction='none')
        loss = loss.mean(dim=list(range(1, loss.ndim)))
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img):
        """
        Training entry point:
        - Move to the module's device / denoiser dtype
        - Draw random timesteps
        - Normalize to [-1,1]
        - Compute loss

        Args:
            img (Tensor): batch shaped [B,C,H,W] with (H,W) == image_size.
        """
        img = img.to(device=self.device, dtype=next(
            self.model.parameters()).dtype)
        b, _, h, w = img.shape
        assert (
            h, w) == self.image_size, f"image must be {self.image_size}, got {(h,w)}"
        t = torch.randint(0, self.num_timesteps, (b,),
                          device=img.device).long()
        img = self.normalize(img)
        return self.p_losses(img, t)

    # ------------------------------------------------------------------ #
    # Sampling
    # ------------------------------------------------------------------ #
    @torch.inference_mode()
    def p_sample(self, x, t: int, x_self_cond=None):
        """
        Compute one reverse step:
        - predict (epsilon, x0), compute posterior q(x_{t-1}|x_t, x0)
        - sample from that Gaussian (add noise except at t=0)

        Returns:
            (x_{t-1}, x0) where x0 is the denoiser's estimate at this step.
        """
        b = x.shape[0]
        tt = torch.full((b,), t, device=self.device, dtype=torch.long)
        # Honour the constructor's clamp_x0 flag instead of always clamping.
        _, x0 = self.model_predictions(
            x, tt, x_self_cond, clip_x_start=self.clamp_x0)
        mean, _, log_var = self.q_posterior(x0, x, tt)
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + (0.5 * log_var).exp() * noise, x0

    @torch.inference_mode()
    def ddpm_sample(self, shape):
        """
        DDPM sampling with T steps (slow, high quality).

        Args:
            shape: (B, C, H, W) of the batch to generate.

        Returns:
            Tensor of that shape, passed through unnormalize().
        """
        # Match the buffer dtype so a non-float32 denoiser gets matching input.
        img = torch.randn(shape, device=self.device, dtype=self.betas.dtype)
        x0 = None
        for t in reversed(range(self.num_timesteps)):
            self_cond = x0 if self.self_condition else None
            img, x0 = self.p_sample(img, t, self_cond)
        return self.unnormalize(img)

    @torch.inference_mode()
    def ddim_sample(self, shape):
        """
        DDIM sampling with S < T steps (fast, often good quality).
        Deterministic when eta=0.0.

        Args:
            shape: (B, C, H, W) of the batch to generate.

        Returns:
            Tensor of that shape, passed through unnormalize().
        """
        T, S, eta = self.num_timesteps, self.sampling_steps, self.ddim_sampling_eta

        # S+1 points from T-1 down to -1; consecutive pairs are (t, t_next),
        # and t_next == -1 marks the final step.
        times = torch.linspace(-1, T - 1, steps=S + 1,
                               device=self.device).long().flip(0)
        pairs = list(zip(times[:-1].tolist(), times[1:].tolist()))

        img = torch.randn(shape, device=self.device, dtype=self.betas.dtype)
        x0 = None

        for t, t_next in pairs:
            tt = torch.full(
                (shape[0],), t, device=self.device, dtype=torch.long)
            # Same self-conditioning rule as ddpm_sample: feed back the
            # previous x0 estimate when the module was built with it.
            self_cond = x0 if self.self_condition else None
            pred_noise, x0 = self.model_predictions(
                img, tt, self_cond, clip_x_start=self.clamp_x0,
                rederive_pred_noise=True)

            if t_next < 0:
                # Last step: the x0 estimate is the sample.
                img = x0
                continue

            a_t, a_next = self.alphas_cumprod[t], self.alphas_cumprod[t_next]
            sigma = eta * ((1 - a_t / a_next) *
                           (1 - a_next) / (1 - a_t)).sqrt()
            c = (1 - a_next - sigma ** 2).sqrt()
            noise = torch.randn_like(img)

            img = x0 * a_next.sqrt() + c * pred_noise + sigma * noise

        return self.unnormalize(img)

    @torch.inference_mode()
    def sample(self, batch_size=16):
        """
        Public sampling API:
        - choose DDPM or DDIM depending on `sampling_steps`
        - returns a batch shaped (batch_size, channels, H, W), in [0,1]
          when auto_normalize is on
        """
        H, W = self.image_size
        fn = self.ddim_sample if self.is_ddim_sampling else self.ddpm_sample
        return fn((batch_size, self.channels, H, W))

    # ------------------------------------------------------------------ #
    # Visualization helpers
    # ------------------------------------------------------------------ #
    @torch.inference_mode()
    def ddpm_sample_trajectory(self, shape, record_every=50, return_x0=False):
        """
        DDPM sampling but also record intermediate frames.
        - record_every: save a snapshot every N steps (also includes first/last).
        - return_x0: if True, also store predicted x0 at the same checkpoints.

        At each checkpoint the stored x_t is the state entering that step and
        the stored x0 is the denoiser's estimate computed from it, so the two
        lists always line up. The final sample is appended to both.

        Returns:
            final_img [B,C,H,W] in [0,1],
            frames_xt: list of tensors in [0,1], each [B,C,H,W]
            frames_x0 (or None): same length as frames_xt if return_x0=True

        Raises:
            ValueError: if record_every is smaller than 1.
        """
        if record_every < 1:
            raise ValueError(f"record_every must be >= 1, got {record_every}")

        img = torch.randn(shape, device=self.device, dtype=self.betas.dtype)
        frames_xt = []
        frames_x0 = [] if return_x0 else None

        x0 = None
        T = self.num_timesteps

        for t in reversed(range(T)):
            is_checkpoint = t == T - 1 or t == 0 or (t % record_every) == 0
            x_t = img

            self_cond = x0 if self.self_condition else None
            img, x0 = self.p_sample(img, t, self_cond)

            if is_checkpoint:
                frames_xt.append(self.unnormalize(x_t.clamp(-1, 1)))
                if return_x0:
                    frames_x0.append(self.unnormalize(x0.clamp(-1, 1)))

        # Always close the trajectory with the final sample.
        frames_xt.append(self.unnormalize(img.clamp(-1, 1)))
        if return_x0 and x0 is not None:
            frames_x0.append(self.unnormalize(x0.clamp(-1, 1)))

        return self.unnormalize(img), frames_xt, frames_x0

    @torch.no_grad()
    def forward_noising_trajectory(self, x0, t_values):
        """
        Visualize forward diffusion q(x_t | x_0) at selected t.
        Args:
            x0: clean images in [0,1], [B,C,H,W]
            t_values: list/iterable of ints (0..T-1)

        Returns:
            frames_xt: list of tensors, each [B,C,H,W], one per entry of
            t_values, passed through unnormalize() without clamping.
        """
        # Cast like forward() does so the inputs match the schedule buffers.
        x0n = self.normalize(
            x0.to(device=self.device, dtype=self.betas.dtype))
        batch = x0n.size(0)
        frames = []
        for t in t_values:
            tt = torch.full((batch,), int(t),
                            device=self.device, dtype=torch.long)
            frames.append(self.unnormalize(self.q_sample(x0n, tt)))
        return frames
|
|