| | """VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import math |
| |
|
| | import torch |
| | from torch import Tensor |
| |
|
| |
|
def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
    """Convert logSNR values into VP (alpha, sigma) coefficients, in float32.

    The VP constraint alpha^2 + sigma^2 = 1 holds automatically because
    sigmoid(x) + sigmoid(-x) = 1.

    Args:
        lmb: logSNR tensor of any dtype; computation is done in float32.

    Returns:
        Tuple ``(alpha, sigma)`` of float32 tensors with the same shape as ``lmb``.
    """
    logsnr32 = lmb.to(dtype=torch.float32)
    alpha = torch.sigmoid(logsnr32).sqrt()
    sigma = torch.sigmoid(-logsnr32).sqrt()
    return alpha, sigma
| |
|
| |
|
def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
    """Reshape a per-sample ``[B]`` coefficient so it broadcasts against ``x``.

    Args:
        coeff: 1D tensor with one entry per sample (batch dimension of ``x``).
        x: Reference tensor whose leading dim is the batch.

    Returns:
        ``coeff`` viewed as ``[B, 1, 1, ...]`` with ``x.dim()`` dimensions total.
    """
    trailing = [1] * (x.dim() - 1)
    return coeff.view(int(x.shape[0]), *trailing)
| |
|
| |
|
| | def _cosine_interpolated_params( |
| | logsnr_min: float, logsnr_max: float |
| | ) -> tuple[float, float]: |
| | """Compute (a, b) for cosine-interpolated logSNR schedule. |
| | |
| | logsnr(t) = -2 * log(tan(a*t + b)) |
| | logsnr(0) = logsnr_max, logsnr(1) = logsnr_min |
| | """ |
| | b = math.atan(math.exp(-0.5 * logsnr_max)) |
| | a = math.atan(math.exp(-0.5 * logsnr_min)) - b |
| | return a, b |
| |
|
| |
|
def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map ``t`` in [0, 1] to logSNR via the cosine-interpolated schedule.

    Evaluates ``-2 * log(tan(a*t + b))`` entirely in float32 on ``t``'s device.
    """
    slope, intercept = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    # Materialize the scalars as float32 tensors so arithmetic stays in float32.
    coeffs = torch.tensor(
        [slope, intercept], device=t32.device, dtype=torch.float32
    )
    angle = coeffs[0] * t32 + coeffs[1]
    return -2.0 * torch.log(torch.tan(angle))
| |
|
| |
|
def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.

    Blends two shifted copies of the base cosine schedule linearly in t:

        lambda(t) = (1-t) * (base(t) + log_change_high)
                  +   t   * (base(t) + log_change_low)
    """
    base = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    weight_low = t.to(dtype=torch.float32)
    shifted_high = base + float(log_change_high)
    shifted_low = base + float(log_change_low)
    return (1.0 - weight_low) * shifted_high + weight_low * shifted_low
| |
|
| |
|
def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
    """Generate a descending t-schedule in [0, 1] for VP diffusion sampling.

    ``num_steps`` is the number of function evaluations (NFE = decoder forward
    passes). Internally the schedule has ``num_steps + 1`` time points
    (including both endpoints).

    Args:
        schedule_type: "linear" or "cosine".
        num_steps: Number of decoder forward passes (NFE), >= 1.

    Returns:
        Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.

    Raises:
        ValueError: If ``schedule_type`` is not recognized.
    """
    # Guarantee at least both endpoints even for degenerate num_steps.
    num_points = max(int(num_steps) + 1, 2)
    if schedule_type == "linear":
        ascending = torch.linspace(0.0, 1.0, num_points)
    elif schedule_type == "cosine":
        idx = torch.arange(num_points, dtype=torch.float32)
        ascending = 0.5 * (1.0 - torch.cos(math.pi * (idx / (num_points - 1))))
    else:
        raise ValueError(
            f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
        )
    # Sampling runs from high noise (t ~ 1) down to clean data (t ~ 0).
    return ascending.flip(0)
| |
|
| |
|
def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Construct the VP initial state ``x_t0 = sigma_start * noise`` (x0 = 0).

    All math is performed in float32.

    Args:
        noise: Gaussian noise tensor with leading batch dimension.
        t_start: Scalar (or broadcastable) start time, expanded per sample.
        logsnr_min / logsnr_max: Endpoints of the base cosine schedule.
        log_change_high / log_change_low: SiD2 resolution-dependent shifts.
    """
    num_samples = int(noise.shape[0])
    t_batch = t_start.expand(num_samples).to(dtype=torch.float32)
    logsnr_start = shifted_cosine_interpolated_logsnr_from_t(
        t_batch,
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    # Only sigma is needed here: with x0 = 0 the alpha term vanishes.
    sigma_start = alpha_sigma_from_logsnr(logsnr_start)[1]
    return broadcast_time_like(sigma_start, noise) * noise.to(dtype=torch.float32)
| |
|
| |
|
| | def sample_noise( |
| | shape: tuple[int, ...], |
| | *, |
| | noise_std: float = 1.0, |
| | seed: int | None = None, |
| | device: torch.device | None = None, |
| | dtype: torch.dtype = torch.float32, |
| | ) -> Tensor: |
| | """Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility.""" |
| | if seed is None: |
| | noise = torch.randn( |
| | shape, device=device or torch.device("cpu"), dtype=torch.float32 |
| | ) |
| | else: |
| | gen = torch.Generator(device="cpu") |
| | gen.manual_seed(int(seed)) |
| | noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32) |
| | noise = noise.mul(float(noise_std)) |
| | target_device = device if device is not None else torch.device("cpu") |
| | return noise.to(device=target_device, dtype=dtype) |
| |
|