| import torch | |
def get_timesteps(schedule: str, k_steps: int, exp_scale: float = 1.0):
    """Split the unit interval [0, 1] into ``k_steps`` consecutive steps.

    Args:
        schedule: How step sizes are distributed.
            ``"linear"``: all steps have size ``1 / k_steps``.
            ``"cosine"``: step sizes are proportional to ``1 + cos(pi * t)``.
            Early steps are large and later steps shrink.
            ``"exp"``: step sizes are proportional to ``exp(-exp_scale * t)``.
        k_steps: Number of steps. Must be at least 1.
        exp_scale: Decay rate used only by the ``"exp"`` schedule.

    For the non-linear schedules, ``t`` is the evenly spaced grid
    ``0, 1/k_steps, ..., (k_steps - 1)/k_steps``.

    Returns:
        ``(t0, dt)``, two 1-D tensors of length ``k_steps``. ``dt`` holds the
        step sizes, which sum to 1 up to floating-point rounding. ``t0`` holds
        the start time of each step: ``t0[0] == 0`` and
        ``t0[i] == dt[:i].sum()``.

    Raises:
        ValueError: If ``k_steps < 1`` or ``schedule`` is not recognised.
    """
    # With k_steps == 0 the schedules below yield an empty dt, but t0 would
    # still contain the leading zero, giving mismatched lengths. Reject it.
    if k_steps < 1:
        raise ValueError(f"k_steps must be >= 1, got {k_steps}")

    # k_steps evenly spaced left endpoints in [0, 1); the endpoint 1 is dropped.
    t = torch.linspace(0, 1, k_steps + 1)[:-1]
    if schedule == "linear":
        dt = torch.ones(k_steps) / k_steps
    elif schedule == "cosine":
        # +1 keeps every weight non-negative. cos(pi * t) >= -1 on [0, 1).
        dt = torch.cos(t * torch.pi) + 1
        dt /= torch.sum(dt)
    elif schedule == "exp":
        # exp(0) == 1 is always present, so the normaliser is never zero.
        dt = torch.exp(-t * exp_scale)
        dt /= torch.sum(dt)
    else:
        raise ValueError(f"Invalid schedule: {schedule}")

    # Exclusive prefix sum: each step starts where the previous one ended.
    t0 = torch.cat((torch.zeros(1), torch.cumsum(dt, dim=0)[:-1]))
    return t0, dt