virtual-tryon / src /fashn_vton /utils /sampling.py
Hemil Ghori
clean deploy
756b108
"""Sampling utilities for Rectified Flow inference."""
import math
import torch
def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    """
    Warp flow-matching timesteps with an exponential shift.

    Evaluates ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)`` elementwise.
    The endpoints are fixed: ``t == 1`` maps to 1. For floating-point
    tensors, ``t == 0`` maps to 0 because ``1 / 0`` evaluates to ``inf``.
    ``get_rf_schedule`` relies on this and passes a grid that includes 0.

    Args:
        mu: Shift strength. ``mu == 0`` (with ``sigma == 1``) is the
            identity; positive values push interior points toward 1 and
            negative values push them toward 0.
        sigma: Exponent applied to ``(1 / t - 1)``; 1.0 in this module.
        t: Timestep tensor with values in [0, 1].

    Returns:
        Tensor of shifted timesteps, broadcast from ``t``.
    """
    scale = math.exp(mu)
    # (1 / t - 1) == (1 - t) / t; inf at t == 0, 0 at t == 1.
    ratio = (1 / t - 1) ** sigma
    return scale / (scale + ratio)
def get_rf_schedule(num_steps: int, mu: float = 1.5, reverse: bool = True) -> list[float]:
    """
    Build the timestep grid for Rectified Flow sampling.

    A uniform grid of ``num_steps + 1`` points running from 1 down to 0 is
    warped with ``time_shift``. The intent, as originally described, is to
    spend more of the step budget at higher noise levels. For ``mu > 0``:
    - ``reverse=False`` shifts with ``+mu``, so points cluster near 1.
    - ``reverse=True`` shifts with ``-mu``, which yields
      ``1 - (forward grid)``, so points cluster near 0.

    Args:
        num_steps: Number of sampling steps; the grid has one extra point.
        mu: Shift strength (larger = stronger warping).
        reverse: If True, return the grid in ascending order, from t=0 to
            t=1 (the denoising direction). If False, return it in
            descending order, from 1 to 0.

    Returns:
        List of ``num_steps + 1`` Python floats.
    """
    signed_mu = -mu if reverse else mu
    uniform = torch.linspace(1, 0, num_steps + 1)
    schedule = time_shift(signed_mu, 1.0, uniform).tolist()
    if reverse:
        # The grid was built descending; flip it so it runs 0 -> 1.
        schedule.reverse()
    return schedule