| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
| | import math |
| | import torch |
| |
|
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| |
|
| |
|
def extract(a, t, x_shape):
    """Gather per-timestep coefficients from `a` at indices `t`.

    Returns a tensor of shape (B, 1, 1, ...) with as many trailing singleton
    dims as needed to broadcast against a tensor of shape `x_shape`.
    """
    b = t.shape[0]
    gathered = a.gather(-1, t)
    broadcast_shape = (b,) + (1,) * (len(x_shape) - 1)
    return gathered.reshape(broadcast_shape)
| |
|
| |
|
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine noise schedule (Nichol & Dhariwal, 2021).

    Builds alpha_bar(t) = cos^2(((t/T + s) / (1 + s)) * pi/2), normalized so
    alpha_bar(0) = 1, then derives betas from consecutive ratios and clips
    them to [0, 0.999] to avoid singularities at the end of the schedule.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float32) / timesteps
    f = torch.cos((grid + s) / (1 + s) * math.pi * 0.5) ** 2
    abar = f / f[0]
    betas = 1.0 - abar[1:] / abar[:-1]
    return torch.clamp(betas, 0, 0.999)
| |
|
| |
|
class GaussianDiffusion(nn.Module):
    """
    Core diffusion module that wraps a denoiser (UNet):
      - Precomputes diffusion constants (betas, alphas, cumulative products, ...)
      - Provides the training loss (forward): randomly pick t, add noise,
        regress the configured target
      - Provides sampling loops (DDPM or DDIM)

    The denoiser must have forward(x, t, [x_self_cond]), returning a predicted
    target (epsilon, x0, or v depending on `objective`), and expose a
    `.channels` attribute.
    """

    def __init__(self, model, *, image_size, timesteps=400, beta_schedule='cosine',
                 objective='pred_noise', sampling_steps=None, eta=0.0,
                 self_condition=False, auto_normalize=True, clamp_x0=True):
        """
        Args:
            model (nn.Module): denoiser network (e.g., UNet).
            image_size (int or (h,w)): training/sampling resolution (must match UNet).
            timesteps (int): T. Smaller (e.g., 400) is enough for MNIST.
            beta_schedule (str): only 'cosine' implemented here for simplicity.
            objective (str): 'pred_noise'|'pred_x0'|'pred_v' (training target).
            sampling_steps (int or None): if set < T => DDIM sampling with S steps;
                else DDPM with all T steps.
            eta (float): DDIM stochasticity (0.0 => deterministic).
            self_condition (bool): optional self-conditioning flag.
            auto_normalize (bool): map inputs [0,1] <-> [-1,1] inside the module.
            clamp_x0 (bool): clamp predicted x0 to [-1,1] during sampling for stability.
        """
        super().__init__()
        self.model = model
        # Adopt the denoiser's dtype/device so every precomputed constant
        # lives exactly where the model parameters live.
        param = next(model.parameters())
        param_dtype = param.dtype
        param_device = param.device
        self.channels = model.channels
        self.self_condition = self_condition
        self.objective = objective
        self.clamp_x0 = clamp_x0

        # Optional [0,1] <-> [-1,1] mapping kept inside the module so callers
        # can always pass/receive images in [0,1].
        self.normalize = (lambda x: x * 2 - 1) if auto_normalize else (lambda x: x)
        self.unnormalize = (lambda x: (x + 1) * 0.5) if auto_normalize else (lambda x: x)

        # Allow a bare int resolution; store as an (h, w) tuple.
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size

        if beta_schedule != 'cosine':
            raise NotImplementedError(
                "For MNIST small, keep beta_schedule='cosine'")
        betas = cosine_beta_schedule(timesteps).to(
            device=param_device, dtype=param_dtype)

        # Standard DDPM quantities: alpha_t = 1 - beta_t, alpha_bar_t = prod(alpha).
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1 (left-pad).
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

        self.num_timesteps = int(betas.shape[0])
        # If sampling_steps is falsy (None or 0), fall back to full-T DDPM.
        self.sampling_steps = int(sampling_steps) if sampling_steps else self.num_timesteps
        self.is_ddim_sampling = self.sampling_steps < self.num_timesteps
        self.ddim_sampling_eta = float(eta)

        # Register constants as buffers so .to()/.cuda()/state_dict handle them.
        self.register_buffer('betas', betas)
        self.register_buffer('alphas_cumprod', alphas_cumprod)
        self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod',
                             torch.sqrt(1.0 - alphas_cumprod))
        self.register_buffer('sqrt_recip_alphas_cumprod',
                             torch.sqrt(1.0 / alphas_cumprod))
        self.register_buffer('sqrt_recipm1_alphas_cumprod',
                             torch.sqrt(1.0 / alphas_cumprod - 1.0))

        # Closed-form posterior q(x_{t-1} | x_t, x_0):
        #   mean = coef1 * x0 + coef2 * x_t, with the variance below.
        posterior_variance = betas * \
            (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        self.register_buffer('posterior_variance', posterior_variance)
        # Clamp before log: posterior variance is 0 at t = 0.
        self.register_buffer('posterior_log_variance_clipped', torch.log(
            posterior_variance.clamp(min=1e-20)))
        self.register_buffer('posterior_mean_coef1', betas *
                             torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
        self.register_buffer('posterior_mean_coef2', (1.0 - alphas_cumprod_prev)
                             * torch.sqrt(1.0 - betas) / (1.0 - alphas_cumprod))

        # Per-timestep loss weights from the signal-to-noise ratio:
        # uniform (all-ones) for eps-prediction, SNR for x0, SNR/(SNR+1) for v.
        snr = alphas_cumprod / (1 - alphas_cumprod)
        if objective == 'pred_noise':
            loss_weight = snr / snr  # all-ones, matching snr's dtype/device
        elif objective == 'pred_x0':
            loss_weight = snr
        else:
            loss_weight = snr / (snr + 1)
        self.register_buffer('loss_weight', loss_weight)

    @property
    def device(self):
        """Convenience: returns the device where buffers live."""
        return self.betas.device

    def q_sample(self, x0, t, noise=None):
        """
        Sample x_t from q(x_t | x_0):
            x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise
        """
        if noise is None:
            noise = torch.randn_like(x0)
        return extract(self.sqrt_alphas_cumprod, t, x0.shape) * x0 + \
            extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise

    def predict_start_from_noise(self, x_t, t, eps):
        """Given an epsilon prediction, reconstruct x0 by inverting q_sample."""
        return extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps

    def predict_noise_from_start(self, x_t, t, x0):
        """Given an x0 prediction, reconstruct epsilon (inverse of the above)."""
        return (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) / \
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    def predict_v(self, x0, t, eps):
        """v-parameterization: v = sqrt(alpha_bar)*eps - sqrt(1-alpha_bar)*x0."""
        return extract(self.alphas_cumprod.sqrt(), t, x0.shape) * eps - \
            extract((1.0 - self.alphas_cumprod).sqrt(), t, x0.shape) * x0

    def predict_start_from_v(self, x_t, t, v):
        """Given a v prediction, reconstruct x0."""
        return extract(self.alphas_cumprod.sqrt(), t, x_t.shape) * x_t - \
            extract((1.0 - self.alphas_cumprod).sqrt(), t, x_t.shape) * v

    def model_predictions(self, x, t, x_self_cond=None, clip_x_start=False, rederive_pred_noise=False):
        """
        Run the denoiser and return (pred_noise, x0):
          - If objective == pred_noise: UNet predicts epsilon directly.
          - If objective == pred_x0:    UNet predicts x0 directly.
          - If objective == pred_v:     UNet predicts v; we convert to x0 & epsilon.

        Args:
            x (Tensor): noised image x_t.
            t (LongTensor): time indices.
            x_self_cond (Tensor|None): optional self-conditioning input.
            clip_x_start (bool): clamp x0 to [-1,1] after prediction.
            rederive_pred_noise (bool): if True, recompute epsilon from the
                clamped x0 so (x0, eps) stay mutually consistent.

        Returns:
            (pred_noise, x0) both shaped like x.
        """
        # Only pass the third argument when self-conditioning is supplied, so
        # denoisers with a 2-arg forward still work.
        out = self.model(
            x, t, x_self_cond) if x_self_cond is not None else self.model(x, t)

        maybe_clip = (lambda z: z.clamp(-1, 1)
                      ) if clip_x_start else (lambda z: z)

        if self.objective == 'pred_noise':
            pred_noise = out
            x0 = self.predict_start_from_noise(x, t, pred_noise)
            x0 = maybe_clip(x0)
            if clip_x_start and rederive_pred_noise:
                # Clamping changed x0, so eps must be rederived to match it.
                pred_noise = self.predict_noise_from_start(x, t, x0)

        elif self.objective == 'pred_x0':
            x0 = maybe_clip(out)
            pred_noise = self.predict_noise_from_start(x, t, x0)

        else:  # 'pred_v'
            v = out
            x0 = self.predict_start_from_v(x, t, v)
            x0 = maybe_clip(x0)
            pred_noise = self.predict_noise_from_start(x, t, x0)

        return pred_noise, x0

    def q_posterior(self, x0, x_t, t):
        """
        Compute the Gaussian q(x_{t-1} | x_t, x0) parameters:
            mean = coef1 * x0 + coef2 * x_t
            var, log_var: closed-form from betas and alpha_bars (precomputed).
        """
        mean = extract(self.posterior_mean_coef1, t, x_t.shape) * x0 + \
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        var = extract(self.posterior_variance, t, x_t.shape)
        log_var = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return mean, var, log_var

    def p_losses(self, x_start, t, noise=None):
        """
        DDPM training objective:
          - Sample x_t ~ q(x_t | x_0)
          - Predict the target chosen by `objective` and take a weighted MSE
          - Self-conditioning (if enabled) feeds a detached x0 estimate back in
        """
        noise = torch.randn_like(x_start) if noise is None else noise
        x = self.q_sample(x_start, t, noise)

        x_self_cond = None
        # 50% of the time, compute a no-grad x0 estimate and condition on it,
        # so the model learns to work both with and without self-conditioning.
        if self.self_condition and torch.rand(1, device=self.device) < 0.5:
            with torch.no_grad():
                _, x_self_cond = self.model_predictions(
                    x, t, None, clip_x_start=True)

        model_out = self.model(
            x, t, x_self_cond) if x_self_cond is not None else self.model(x, t)

        # Regression target depends on the chosen parameterization.
        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        else:  # 'pred_v'
            v = self.predict_v(x_start, t, noise)
            target = v

        # Per-sample MSE (mean over C,H,W), then SNR-based per-timestep weight.
        loss = F.mse_loss(model_out, target, reduction='none')
        loss = loss.mean(dim=list(range(1, loss.ndim)))
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    def forward(self, img):
        """
        Training entry point:
          - Move input to the model's device/dtype and normalize to [-1,1]
          - Draw one random timestep per sample
          - Compute the diffusion loss (scalar)
        """
        img = img.to(device=self.device, dtype=next(
            self.model.parameters()).dtype)
        b, c, h, w = img.shape
        # NOTE: assert is stripped under `python -O`; relied on here only as a
        # development-time sanity check.
        assert (
            h, w) == self.image_size, f"image must be {self.image_size}, got {(h,w)}"
        t = torch.randint(0, self.num_timesteps, (b,),
                          device=img.device).long()
        img = self.normalize(img)
        return self.p_losses(img, t)

    @torch.inference_mode()
    def p_sample(self, x, t: int, x_self_cond=None):
        """
        One DDPM reverse step:
          - predict (epsilon, x0), compute posterior q(x_{t-1}|x_t, x0)
          - sample from that Gaussian (noise added except at t=0)

        Returns:
            (x_{t-1} sample, predicted x0)
        """
        b = x.shape[0]
        tt = torch.full((b,), t, device=self.device, dtype=torch.long)
        pred_noise, x0 = self.model_predictions(
            x, tt, x_self_cond, clip_x_start=True)
        mean, _, log_var = self.q_posterior(x0, x, tt)
        # No noise at the final step: return the posterior mean deterministically.
        noise = torch.randn_like(x) if t > 0 else 0.0
        return mean + (0.5 * log_var).exp() * noise, x0

    @torch.inference_mode()
    def ddpm_sample(self, shape):
        """
        DDPM sampling with all T steps (slow, high quality).
        Returns images in [0,1] (after unnormalize).
        """
        img = torch.randn(shape, device=self.device)
        x0 = None
        for t in reversed(range(self.num_timesteps)):
            # Previous step's x0 estimate doubles as the self-conditioning input.
            self_cond = x0 if self.self_condition else None
            img, x0 = self.p_sample(img, t, self_cond)
        return self.unnormalize(img)

    @torch.inference_mode()
    def ddim_sample(self, shape):
        """
        DDIM sampling with S < T steps (fast, often good quality).
        Deterministic when eta=0.0.
        """
        T, S, eta = self.num_timesteps, self.sampling_steps, self.ddim_sampling_eta
        # Time pairs run from T-1 down to -1; t_next == -1 marks the last jump
        # straight to the clean image.
        times = torch.linspace(-1, T - 1, steps=S + 1,
                               device=self.device).long().flip(0)
        pairs = list(zip(times[:-1].tolist(), times[1:].tolist()))

        img = torch.randn(shape, device=self.device)
        x0 = None

        for t, t_next in pairs:
            tt = torch.full(
                (shape[0],), t, device=self.device, dtype=torch.long)
            pred_noise, x0 = self.model_predictions(
                img, tt, None, clip_x_start=True, rederive_pred_noise=True)

            if t_next < 0:
                # Final step: output the predicted clean image directly.
                img = x0
                continue

            # DDIM update: x_{t_next} = sqrt(abar_next)*x0 + c*eps + sigma*z
            a_t, a_next = self.alphas_cumprod[t], self.alphas_cumprod[t_next]
            sigma = eta * ((1 - a_t / a_next) *
                           (1 - a_next) / (1 - a_t)).sqrt()
            c = (1 - a_next - sigma ** 2).sqrt()
            noise = torch.randn_like(img)

            img = x0 * a_next.sqrt() + c * pred_noise + sigma * noise

        return self.unnormalize(img)

    @torch.inference_mode()
    def sample(self, batch_size=16):
        """
        Public sampling API:
          - chooses DDPM or DDIM depending on `sampling_steps`
          - returns a batch of images in [0,1], shape [B, C, H, W]
        """
        H, W = self.image_size
        fn = self.ddim_sample if self.is_ddim_sampling else self.ddpm_sample
        return fn((batch_size, self.channels, H, W))

    @torch.inference_mode()
    def ddpm_sample_trajectory(self, shape, record_every=50, return_x0=False):
        """
        DDPM sampling that also records intermediate frames.
          - record_every: save a snapshot every N steps (also includes first/last).
          - return_x0: if True, also store predicted x0 at the same checkpoints.

        Returns:
            final_img [B,C,H,W] in [0,1],
            frames_xt: list of tensors in [0,1], each [B,C,H,W]
            frames_x0 (or None): roughly parallel to frames_xt if return_x0=True
                (the very first checkpoint has no x0 estimate yet, so it is skipped)
        """
        img = torch.randn(shape, device=self.device)
        frames_xt = []
        frames_x0 = [] if return_x0 else None

        x0 = None
        T = self.num_timesteps

        for t in reversed(range(T)):
            # Snapshot BEFORE the step at the first, last, and every Nth timestep.
            if t == T - 1 or t == 0 or (t % record_every) == 0:
                # Clamp for display only; sampling state itself is untouched.
                frames_xt.append(self.unnormalize(img.clamp(-1, 1)))
                if return_x0 and x0 is not None:
                    frames_x0.append(self.unnormalize(x0.clamp(-1, 1)))

            self_cond = x0 if self.self_condition else None
            img, x0 = self.p_sample(img, t, self_cond)

        # Final frame after the last reverse step.
        frames_xt.append(self.unnormalize(img.clamp(-1, 1)))
        if return_x0:
            frames_x0.append(self.unnormalize(x0.clamp(-1, 1)))

        return self.unnormalize(img), frames_xt, frames_x0

    @torch.no_grad()
    def forward_noising_trajectory(self, x0, t_values):
        """
        Visualize forward diffusion q(x_t | x_0) at selected timesteps.

        Args:
            x0: clean images in [0,1], [B,C,H,W]
            t_values: list/iterable of ints (0..T-1)

        Returns:
            frames_xt: list of tensors, each [B,C,H,W]
            (values are unnormalized but NOT clamped, so mild excursions
            outside [0,1] from the added noise are possible)
        """
        x0n = self.normalize(x0.to(self.device))
        frames = []
        for t in t_values:
            tt = torch.full((x0n.size(0),), int(
                t), device=self.device, dtype=torch.long)
            xt = self.q_sample(x0n, tt)
            frames.append(self.unnormalize(xt))
        return frames
| |
|