lsnu's picture
Add files using upload-large-folder tool
912c7e2 verified
import torch
def get_timesteps(schedule: str, k_steps: int, exp_scale: float = 1.0):
    """Split the unit interval [0, 1] into ``k_steps`` consecutive steps.

    Step sizes are derived from a uniform grid ``t = i / k_steps`` for
    ``i = 0 .. k_steps - 1``.

    Args:
        schedule: How step sizes are distributed over the interval.
            ``"linear"`` -- every step has size ``1 / k_steps``.
            ``"cosine"`` -- sizes proportional to ``cos(pi * t) + 1``; steps
            shrink towards the end of the interval.
            ``"exp"`` -- sizes proportional to ``exp(-exp_scale * t)``; for
            ``exp_scale > 0`` steps shrink towards the end, for
            ``exp_scale == 0`` this equals the linear schedule.
        k_steps: Number of steps; must be at least 1.
        exp_scale: Decay rate, used only by the ``"exp"`` schedule.

    Returns:
        ``(t0, dt)``: two 1-D tensors of length ``k_steps``. ``dt`` holds the
        step sizes (summing to 1 up to float rounding) and ``t0[i]`` is the
        start of step ``i``: ``t0[0] == 0`` and ``t0[i] == dt[:i].sum()``.

    Raises:
        ValueError: If ``k_steps < 1`` or ``schedule`` is not recognised.
    """
    if k_steps < 1:
        # Without this check k_steps == 0 yields a length-1 t0 (the leading
        # zero below) next to an empty dt, and negative values fail inside
        # torch.linspace with an unrelated-looking error.
        raise ValueError(f"k_steps must be >= 1, got {k_steps}")

    # Uniform grid of k_steps points: 0, 1/k, ..., (k-1)/k (endpoint 1 dropped).
    t = torch.linspace(0, 1, k_steps + 1)[:-1]
    if schedule == "linear":
        dt = torch.ones(k_steps) / k_steps
    elif schedule == "cosine":
        dt = torch.cos(t * torch.pi) + 1
        dt /= torch.sum(dt)  # normalise so the steps cover exactly [0, 1]
    elif schedule == "exp":
        dt = torch.exp(-t * exp_scale)
        dt /= torch.sum(dt)  # normalise so the steps cover exactly [0, 1]
    else:
        raise ValueError(f"Invalid schedule: {schedule}")

    # Start of step i is the sum of all preceding steps (exclusive prefix sum).
    t0 = torch.cat((torch.zeros(1), torch.cumsum(dt, dim=0)[:-1]))
    return t0, dt