"""DDIM and DPM++2M samplers for VP diffusion with x-prediction objective.""" from __future__ import annotations from typing import Protocol import torch from torch import Tensor from .vp_diffusion import ( alpha_sigma_from_logsnr, broadcast_time_like, shifted_cosine_interpolated_logsnr_from_t, ) class DecoderForwardFn(Protocol): """Callable that predicts x0 from (x_t, t, latents).""" def __call__( self, x_t: Tensor, t: Tensor, latents: Tensor, *, mask_tokens: bool = False, ) -> Tensor: ... def _reconstruct_eps_from_x0( *, x_t: Tensor, x0_hat: Tensor, alpha: Tensor, sigma: Tensor ) -> Tensor: """Reconstruct eps_hat from (x_t, x0_hat) under VP parameterization. eps_hat = (x_t - alpha * x0_hat) / sigma. All float32. """ alpha_view = broadcast_time_like(alpha, x_t).to(dtype=torch.float32) sigma_view = broadcast_time_like(sigma, x_t).to(dtype=torch.float32) x_t_f32 = x_t.to(torch.float32) x0_f32 = x0_hat.to(torch.float32) return (x_t_f32 - alpha_view * x0_f32) / sigma_view def _ddim_step( *, x0_hat: Tensor, eps_hat: Tensor, alpha_next: Tensor, sigma_next: Tensor, ref: Tensor, ) -> Tensor: """DDIM step: x_next = alpha_next * x0_hat + sigma_next * eps_hat.""" a = broadcast_time_like(alpha_next, ref).to(dtype=torch.float32) s = broadcast_time_like(sigma_next, ref).to(dtype=torch.float32) return a * x0_hat + s * eps_hat def run_ddim( *, forward_fn: DecoderForwardFn, initial_state: Tensor, schedule: Tensor, latents: Tensor, logsnr_min: float, logsnr_max: float, log_change_high: float = 0.0, log_change_low: float = 0.0, pdg_enabled: bool = False, pdg_strength: float = 1.1, device: torch.device | None = None, ) -> Tensor: """Run DDIM sampling loop. Args: forward_fn: Decoder forward function (x_t, t, latents) -> x0_hat. initial_state: Starting noised state [B, C, H, W] in float32. schedule: Descending t-schedule [num_steps] in [0, 1]. latents: Encoder latents [B, bottleneck_dim, h, w]. logsnr_min, logsnr_max: VP schedule endpoints. log_change_high, log_change_low: Shifted-cosine schedule parameters. pdg_enabled: Whether to use token-level Path-Drop Guidance. pdg_strength: CFG-like strength for PDG (use small values: 1.05–1.2). device: Target device. Returns: Denoised samples [B, C, H, W] in float32. """ run_device = device or initial_state.device batch_size = int(initial_state.shape[0]) state = initial_state.to(device=run_device, dtype=torch.float32) # Precompute logSNR, alpha, sigma for all schedule points lmb = shifted_cosine_interpolated_logsnr_from_t( schedule.to(device=run_device), logsnr_min=logsnr_min, logsnr_max=logsnr_max, log_change_high=log_change_high, log_change_low=log_change_low, ) alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb) for i in range(int(schedule.numel()) - 1): t_i = schedule[i] a_t = alpha_sched[i].expand(batch_size) s_t = sigma_sched[i].expand(batch_size) a_next = alpha_sched[i + 1].expand(batch_size) s_next = sigma_sched[i + 1].expand(batch_size) # Model prediction t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32) if pdg_enabled: x0_uncond = forward_fn(state, t_vec, latents, mask_tokens=True).to( torch.float32 ) x0_cond = forward_fn(state, t_vec, latents, mask_tokens=False).to( torch.float32 ) x0_hat = x0_uncond + pdg_strength * (x0_cond - x0_uncond) else: x0_hat = forward_fn(state, t_vec, latents, mask_tokens=False).to( torch.float32 ) eps_hat = _reconstruct_eps_from_x0( x_t=state, x0_hat=x0_hat, alpha=a_t, sigma=s_t ) state = _ddim_step( x0_hat=x0_hat, eps_hat=eps_hat, alpha_next=a_next, sigma_next=s_next, ref=state, ) return state def run_dpmpp_2m( *, forward_fn: DecoderForwardFn, initial_state: Tensor, schedule: Tensor, latents: Tensor, logsnr_min: float, logsnr_max: float, log_change_high: float = 0.0, log_change_low: float = 0.0, pdg_enabled: bool = False, pdg_strength: float = 1.1, device: torch.device | None = None, ) -> Tensor: """Run DPM++2M sampling loop. Multi-step solver using exponential integrator formulation in half-lambda space. """ run_device = device or initial_state.device batch_size = int(initial_state.shape[0]) state = initial_state.to(device=run_device, dtype=torch.float32) # Precompute logSNR, alpha, sigma, half-lambda for all schedule points lmb = shifted_cosine_interpolated_logsnr_from_t( schedule.to(device=run_device), logsnr_min=logsnr_min, logsnr_max=logsnr_max, log_change_high=log_change_high, log_change_low=log_change_low, ) alpha_sched, sigma_sched = alpha_sigma_from_logsnr(lmb) half_lambda = 0.5 * lmb.to(torch.float32) x0_prev: Tensor | None = None for i in range(int(schedule.numel()) - 1): t_i = schedule[i] s_t = sigma_sched[i].expand(batch_size) a_next = alpha_sched[i + 1].expand(batch_size) s_next = sigma_sched[i + 1].expand(batch_size) # Model prediction t_vec = t_i.expand(batch_size).to(device=run_device, dtype=torch.float32) if pdg_enabled: x0_uncond = forward_fn(state, t_vec, latents, mask_tokens=True).to( torch.float32 ) x0_cond = forward_fn(state, t_vec, latents, mask_tokens=False).to( torch.float32 ) x0_hat = x0_uncond + pdg_strength * (x0_cond - x0_uncond) else: x0_hat = forward_fn(state, t_vec, latents, mask_tokens=False).to( torch.float32 ) lam_t = half_lambda[i].expand(batch_size) lam_next = half_lambda[i + 1].expand(batch_size) h = (lam_next - lam_t).to(torch.float32) phi_1 = torch.expm1(-h) sigma_ratio = (s_next / s_t).to(torch.float32) if i == 0 or x0_prev is None: # First-order step state = ( sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state - broadcast_time_like(a_next, state).to(torch.float32) * broadcast_time_like(phi_1, state).to(torch.float32) * x0_hat ) else: # Second-order step lam_prev = half_lambda[i - 1].expand(batch_size) h_0 = (lam_t - lam_prev).to(torch.float32) r0 = h_0 / h d1_0 = (x0_hat - x0_prev) / broadcast_time_like(r0, x0_hat) common = broadcast_time_like(a_next, state).to( torch.float32 ) * broadcast_time_like(phi_1, state).to(torch.float32) state = ( sigma_ratio.view(-1, *([1] * (state.dim() - 1))) * state - common * x0_hat - 0.5 * common * d1_0 ) x0_prev = x0_hat return state