Spaces:
Running
Running
File size: 1,286 Bytes
756b108 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | """Sampling utilities for Rectified Flow inference."""
import math
import torch
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
"""
Apply time shift to timesteps for flow matching schedule.
Args:
mu: Time shift parameter (controls schedule steepness)
sigma: Sigma parameter (typically 1.0)
t: Timestep tensor with values in (0, 1]
Returns:
Shifted timesteps
"""
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
def get_rf_schedule(num_steps: int, mu: float = 1.5, reverse: bool = True) -> list[float]:
"""
Generate timestep schedule for Rectified Flow sampling.
Creates a shifted linear schedule that provides better sample quality
by spending more time at higher noise levels.
Args:
num_steps: Number of sampling steps
mu: Time shift parameter (higher = more time at high noise)
reverse: If True, returns schedule from t=0 to t=1 (for denoising)
Returns:
List of timesteps of length num_steps + 1
"""
if reverse:
mu = -mu
timesteps = torch.linspace(1, 0, num_steps + 1)
timesteps = time_shift(mu, 1.0, timesteps)
timesteps = timesteps.tolist()
return timesteps[::-1] if reverse else timesteps
|