"""VP diffusion math: logSNR schedules, alpha/sigma computation, noise construction."""
from __future__ import annotations
import math
import torch
from torch import Tensor
def alpha_sigma_from_logsnr(lmb: Tensor) -> tuple[Tensor, Tensor]:
"""Compute (alpha, sigma) from logSNR in float32.
VP constraint: alpha^2 + sigma^2 = 1.
"""
lmb32 = lmb.to(dtype=torch.float32)
alpha = torch.sqrt(torch.sigmoid(lmb32))
sigma = torch.sqrt(torch.sigmoid(-lmb32))
return alpha, sigma
def broadcast_time_like(coeff: Tensor, x: Tensor) -> Tensor:
"""Broadcast [B] coefficient to match x for per-sample scaling."""
view_shape = (int(x.shape[0]),) + (1,) * (x.dim() - 1)
return coeff.view(view_shape)
def _cosine_interpolated_params(
logsnr_min: float, logsnr_max: float
) -> tuple[float, float]:
"""Compute (a, b) for cosine-interpolated logSNR schedule.
logsnr(t) = -2 * log(tan(a*t + b))
logsnr(0) = logsnr_max, logsnr(1) = logsnr_min
"""
b = math.atan(math.exp(-0.5 * logsnr_max))
a = math.atan(math.exp(-0.5 * logsnr_min)) - b
return a, b
def cosine_interpolated_logsnr_from_t(
    t: Tensor, *, logsnr_min: float, logsnr_max: float
) -> Tensor:
    """Map t in [0,1] to logSNR via the cosine-interpolated schedule.

    Computes logsnr(t) = -2 * log(tan(a*t + b)) in float32, with (a, b)
    chosen so that logsnr(0) = logsnr_max and logsnr(1) = logsnr_min.

    Args:
        t: Time tensor with values in [0, 1] (any float dtype/device).
        logsnr_min: logSNR at t=1 (low-SNR endpoint).
        logsnr_max: logSNR at t=0 (high-SNR endpoint).

    Returns:
        float32 tensor of logSNR values, same shape/device as ``t``.
    """
    a, b = _cosine_interpolated_params(logsnr_min, logsnr_max)
    t32 = t.to(dtype=torch.float32)
    # Python floats broadcast directly against the tensor; this avoids the
    # previous per-call torch.tensor(...) scalar allocations (and, for CUDA
    # inputs, a host-to-device copy for each scalar on every call). Scalars
    # do not participate in dtype promotion, so the result stays float32.
    u = a * t32 + b
    return -2.0 * torch.log(torch.tan(u))
def shifted_cosine_interpolated_logsnr_from_t(
    t: Tensor,
    *,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """SiD2 "shifted cosine" schedule: logSNR with resolution-dependent shifts.

    Linearly interpolates between two shifted copies of the base schedule:
    lambda(t) = (1-t) * (base(t) + log_change_high) + t * (base(t) + log_change_low)
    """
    base_logsnr = cosine_interpolated_logsnr_from_t(
        t, logsnr_min=logsnr_min, logsnr_max=logsnr_max
    )
    weight = t.to(dtype=torch.float32)
    shifted_high = base_logsnr + float(log_change_high)
    shifted_low = base_logsnr + float(log_change_low)
    return (1.0 - weight) * shifted_high + weight * shifted_low
def get_schedule(schedule_type: str, num_steps: int) -> Tensor:
"""Generate a descending t-schedule in [0, 1] for VP diffusion sampling.
``num_steps`` is the number of function evaluations (NFE = decoder forward
passes). Internally the schedule has ``num_steps + 1`` time points
(including both endpoints).
Args:
schedule_type: "linear" or "cosine".
num_steps: Number of decoder forward passes (NFE), >= 1.
Returns:
Descending 1D tensor with ``num_steps + 1`` elements from ~1.0 to ~0.0.
"""
# NOTE: the upstream training code (src/ode/time_schedules.py) uses a
# different convention where num_steps counts schedule *points* (so NFE =
# num_steps - 1). This export package corrects the off-by-one so that
# num_steps means NFE directly. TODO: align the upstream convention.
n = max(int(num_steps) + 1, 2)
if schedule_type == "linear":
base = torch.linspace(0.0, 1.0, n)
elif schedule_type == "cosine":
i = torch.arange(n, dtype=torch.float32)
base = 0.5 * (1.0 - torch.cos(math.pi * (i / (n - 1))))
else:
raise ValueError(
f"Unsupported schedule type: {schedule_type!r}. Use 'linear' or 'cosine'."
)
# Descending: high t (noisy) -> low t (clean)
return torch.flip(base, dims=[0])
def make_initial_state(
    *,
    noise: Tensor,
    t_start: Tensor,
    logsnr_min: float,
    logsnr_max: float,
    log_change_high: float = 0.0,
    log_change_low: float = 0.0,
) -> Tensor:
    """Construct the VP initial state x_t0 = sigma(t_start) * noise.

    Since x0 = 0, the alpha term drops out and only the noise scale is
    applied. All math is done in float32.
    """
    batch_size = int(noise.shape[0])
    t0 = t_start.expand(batch_size).to(dtype=torch.float32)
    logsnr_start = shifted_cosine_interpolated_logsnr_from_t(
        t0,
        logsnr_min=logsnr_min,
        logsnr_max=logsnr_max,
        log_change_high=log_change_high,
        log_change_low=log_change_low,
    )
    # alpha multiplies x0 == 0, so only sigma is used.
    _, sigma_start = alpha_sigma_from_logsnr(logsnr_start)
    return broadcast_time_like(sigma_start, noise) * noise.to(dtype=torch.float32)
def sample_noise(
shape: tuple[int, ...],
*,
noise_std: float = 1.0,
seed: int | None = None,
device: torch.device | None = None,
dtype: torch.dtype = torch.float32,
) -> Tensor:
"""Sample Gaussian noise with optional seeding. CPU-seeded for reproducibility."""
if seed is None:
noise = torch.randn(
shape, device=device or torch.device("cpu"), dtype=torch.float32
)
else:
gen = torch.Generator(device="cpu")
gen.manual_seed(int(seed))
noise = torch.randn(shape, generator=gen, device="cpu", dtype=torch.float32)
noise = noise.mul(float(noise_std))
target_device = device if device is not None else torch.device("cpu")
return noise.to(device=target_device, dtype=dtype)
|