File size: 1,286 Bytes
756b108
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""Sampling utilities for Rectified Flow inference."""

import math

import torch


def time_shift(mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
    """
    Warp flow-matching timesteps with an exponential time shift.

    Every timestep is mapped to ``exp(mu) / (exp(mu) + (1 / t - 1) ** sigma)``.

    - ``t = 1`` is a fixed point for any ``mu`` / ``sigma``.
    - With ``mu = 0`` and ``sigma = 1`` the map is the identity.
    - For ``sigma = 1``, ``mu > 0`` pushes values in (0, 1) toward 1 and
      ``mu < 0`` pushes them toward 0.
    - For a floating-point tensor, an entry of exactly 0 makes ``1 / t``
      infinite, so (for ``sigma > 0``) it maps to 0.

    Args:
        mu: Time shift parameter; its sign sets the direction of the warp.
        sigma: Exponent applied to ``1 / t - 1`` (typically 1.0).
        t: Timestep tensor with values in (0, 1] (0 is tolerated as noted above).

    Returns:
        Shifted timesteps, broadcast to the shape of ``t``.
    """
    scale = math.exp(mu)
    # (1/t - 1) is the "odds" of t being far from 1; the shift rescales it by exp(-mu).
    odds = (1 / t - 1) ** sigma
    return scale / (scale + odds)


def get_rf_schedule(num_steps: int, mu: float = 1.5, reverse: bool = True) -> list[float]:
    """
    Generate timestep schedule for Rectified Flow sampling.

    Creates a shifted linear schedule: ``num_steps + 1`` evenly spaced points
    are warped with ``time_shift`` (sigma = 1.0). For ``mu > 0``, the steps
    are concentrated near the *first* entry of the returned list, which is
    the start of the sampling trajectory in either direction.

    With ``reverse=True`` the shift is applied with ``-mu`` and the result is
    flipped. Because ``s_{-mu}(t) == 1 - s_{mu}(1 - t)``, this yields, up to
    float rounding, exactly ``1 - x`` for each entry ``x`` of the
    ``reverse=False`` schedule.

    Args:
        num_steps: Number of sampling steps; must be non-negative.
        mu: Time shift parameter (higher = more time at high noise)
        reverse: If True, returns schedule from t=0 to t=1 (for denoising);
            otherwise from t=1 to t=0.

    Returns:
        List of timesteps of length num_steps + 1

    Raises:
        ValueError: If ``num_steps`` is negative. Without this check,
            ``num_steps == -1`` would silently produce an empty schedule.
    """
    if num_steps < 0:
        raise ValueError(f"num_steps must be non-negative, got {num_steps}")

    # Mirror the shift for the reversed schedule, leaving the caller's mu untouched.
    shift = -mu if reverse else mu
    grid = torch.linspace(1, 0, num_steps + 1)
    schedule = time_shift(shift, 1.0, grid).tolist()
    return schedule[::-1] if reverse else schedule