| | """DDIM and DPM++2M samplers for VP diffusion with x-prediction objective.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | from typing import Protocol |
| |
|
| | import torch |
| | from torch import Tensor |
| |
|
| | from .vp_diffusion import ( |
| | alpha_sigma_from_logsnr, |
| | broadcast_time_like, |
| | shifted_cosine_interpolated_logsnr_from_t, |
| | ) |
| |
|
| |
|
class DecoderForwardFn(Protocol):
    """Structural type for the decoder forward pass.

    Maps a noised state ``x_t`` at time ``t``, conditioned on encoder
    ``latents``, to a prediction of the clean sample ``x0``. The keyword-only
    ``mask_tokens`` flag requests a token-masked ("unconditional") pass, used
    by the Path-Drop Guidance branch of the samplers below.
    """

    def __call__(
        self,
        x_t: Tensor,
        t: Tensor,
        latents: Tensor,
        *,
        mask_tokens: bool = False,
    ) -> Tensor: ...
| |
|
| |
|
def _reconstruct_eps_from_x0(
    *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor
) -> Tensor:
    """Invert the VP forward relation ``x_t = alpha * x0 + sigma * eps``.

    Returns ``eps_hat = (x_t - alpha * x0_hat) / sigma``; all intermediate
    arithmetic is done in float32.
    """
    a = broadcast_time_like(alpha, x_t).to(dtype=torch.float32)
    s = broadcast_time_like(sigma, x_t).to(dtype=torch.float32)
    residual = x_t.to(torch.float32) - a * x0_hat.to(torch.float32)
    return residual / s
| |
|
| |
|
def _ddim_step(
    *,
    x0_hat: Tensor,
    eps_hat: Tensor,
    alpha_next: Tensor,
    sigma_next: Tensor,
    ref: Tensor,
) -> Tensor:
    """Deterministic DDIM update: ``alpha_next * x0_hat + sigma_next * eps_hat``.

    ``ref`` supplies the target shape for broadcasting the per-batch
    coefficients; the result is float32.
    """
    alpha_b = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32)
    sigma_b = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32)
    return alpha_b * x0_hat + sigma_b * eps_hat
| |
|
| |
|
def run_ddim(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_enabled: bool = False,
    pdg_strength: float = 1.1,
    device: torch.device | None = None,
) -> Tensor:
    """Run the deterministic DDIM sampling loop.

    Args:
        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
        initial_state: Starting noised state [B, C, H, W] in float32.
        schedule: Descending t-schedule [num_steps] in [0, 1].
        latents: Encoder latents [B, bottleneck_dim, h, w].
        logsnr_min, logsnr_max: VP schedule endpoints.
        log_change_high, log_change_low: Shifted-cosine schedule parameters.
        pdg_enabled: Whether to use token-level Path-Drop Guidance.
        pdg_strength: CFG-like strength for PDG (use small values: 1.05–1.2).
        device: Target device.

    Returns:
        Denoised samples [B, C, H, W] in float32.
    """
    target_device = initial_state.device if device is None else device
    n_batch = int(initial_state.shape[0])
    # The sampler iterates in float32 regardless of the caller's dtype.
    state = initial_state.to(device=target_device, dtype=torch.float32)

    # logSNR at every schedule point, then the matching VP (alpha, sigma) pairs.
    logsnr = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=target_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alphas, sigmas = alpha_sigma_from_logsnr(logsnr)

    num_points = int(schedule.numel())
    for step in range(num_points - 1):
        t_batch = (
            schedule[step]
            .expand(n_batch)
            .to(device=target_device, dtype=torch.float32)
        )

        # Predict x0; with PDG, extrapolate from the token-masked prediction
        # toward the unmasked one by pdg_strength (CFG-style blend).
        if pdg_enabled:
            uncond = forward_fn(state, t_batch, latents, mask_tokens=True)
            cond = forward_fn(state, t_batch, latents, mask_tokens=False)
            uncond = uncond.to(torch.float32)
            cond = cond.to(torch.float32)
            x0_pred = uncond + pdg_strength * (cond - uncond)
        else:
            x0_pred = forward_fn(state, t_batch, latents, mask_tokens=False).to(
                torch.float32
            )

        # Recover eps from the x-prediction, then take the DDIM step to the
        # next (alpha, sigma) pair on the schedule.
        eps_pred = _reconstruct_eps_from_x0(
            x_t=state,
            x0_hat=x0_pred,
            alpha=alphas[step].expand(n_batch),
            sigma=sigmas[step].expand(n_batch),
        )
        state = _ddim_step(
            x0_hat=x0_pred,
            eps_hat=eps_pred,
            alpha_next=alphas[step + 1].expand(n_batch),
            sigma_next=sigmas[step + 1].expand(n_batch),
            ref=state,
        )

    return state
| |
|
| |
|
def run_dpmpp_2m(
    *,
    forward_fn: DecoderForwardFn,
    initial_state: Tensor,
    schedule: Tensor,
    latents: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
    pdg_enabled: bool = False,
    pdg_strength: float = 1.1,
    device: torch.device | None = None,
) -> Tensor:
    """Run DPM++2M sampling loop.

    Multi-step solver using exponential integrator formulation in half-lambda space.

    Args:
        forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat.
        initial_state: Starting noised state [B, C, H, W] in float32.
        schedule: Descending t-schedule [num_steps] in [0, 1].
        latents: Encoder latents [B, bottleneck_dim, h, w].
        logsnr_min, logsnr_max: VP schedule endpoints.
        log_change_high, log_change_low: Shifted-cosine schedule parameters.
        pdg_enabled: Whether to use token-level Path-Drop Guidance.
        pdg_strength: CFG-like strength for PDG (use small values: 1.05–1.2).
        device: Target device.

    Returns:
        Denoised samples [B, C, H, W] in float32.
    """
    run_device = device or initial_state.device
    batch_size = int(initial_state.shape[0])
    # The solver iterates in float32 regardless of the caller's dtype.
    state = initial_state.to(device=run_device, dtype=torch.float32)

    # logSNR at every schedule point, then the matching VP (alpha, sigma) pairs.
    lmb = shifted_cosine_interpolated_logsnr_from_t(
        schedule.to(device=run_device),
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb)
    # The integrator works in half-logSNR space: half_lambda = 0.5 * logSNR.
    half_lambda = 0.5 * lmb.to(torch.float32)

    # Previous step's x0 prediction; None until the first step has run.
    x0_prev: Tensor | None = None

    for i in range(int(schedule.numel()) - 1):
        t_i = schedule[i]
        s_t = sigma_sched[i].expand(batch_size)
        a_next = alpha_sched[i + 1].expand(batch_size)
        s_next = sigma_sched[i + 1].expand(batch_size)

        t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32)
        if pdg_enabled:
            # Path-Drop Guidance: extrapolate from the token-masked prediction
            # toward the unmasked one by pdg_strength (CFG-style blend).
            x0_uncond = forward_fn(state, t_vec, latents, mask_tokens=True).to(
                torch.float32
            )
            x0_cond = forward_fn(state, t_vec, latents, mask_tokens=False).to(
                torch.float32
            )
            x0_hat = x0_uncond + pdg_strength * (x0_cond - x0_uncond)
        else:
            x0_hat = forward_fn(state, t_vec, latents, mask_tokens=False).to(
                torch.float32
            )

        # Step size h in half-logSNR space and the exponential-integrator
        # factor phi_1 = expm1(-h).
        lam_t = half_lambda[i].expand(batch_size)
        lam_next = half_lambda[i + 1].expand(batch_size)
        h = (lam_next - lam_t).to(torch.float32)
        phi_1 = torch.expm1(-h)

        sigma_ratio = (s_next / s_t).to(torch.float32)

        if i == 0 or x0_prev is None:
            # First step: no history yet, so take a first-order update
            # x_next = (sigma_next/sigma_t) * x_t - alpha_next * phi_1 * x0_hat.
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - broadcast_time_like(a_next, state).to(torch.float32)
                * broadcast_time_like(phi_1, state).to(torch.float32)
                * x0_hat
            )
        else:
            # Second-order multistep update: add a finite-difference correction
            # d1_0 = (x0_hat - x0_prev) / r0, where r0 = h_prev / h, weighted
            # by 0.5 on top of the first-order term.
            lam_prev = half_lambda[i - 1].expand(batch_size)
            h_0 = (lam_t - lam_prev).to(torch.float32)
            r0 = h_0 / h
            d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat)
            common = broadcast_time_like(a_next, state).to(
                torch.float32
            ) * broadcast_time_like(phi_1, state).to(torch.float32)
            state = (
                sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state
                - common * x0_hat
                - 0.5 * common * d1_0
            )

        # Keep this step's x0 prediction for the next step's correction term.
        x0_prev = x0_hat

    return state
| |
|