| from typing import Union |
| import torch |
|
|
|
|
| class Timesteps: |
| """ |
| Sampling timesteps. |
| It defines the discretization of sampling steps. |
| """ |
|
|
| def __init__( |
| self, |
| T: int, |
| steps: int, |
| device: torch.device = "cpu", |
| ): |
| self.T = T |
| timesteps = torch.arange(T, -1, -(T + 1) / steps, device=device).round().int() |
| self.timesteps = timesteps |
|
|
| def __len__(self) -> int: |
| """ |
| Number of sampling steps. |
| """ |
| return len(self.timesteps) |
|
|
| def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor: |
| return self.timesteps[idx] |
|
|
| def index(self, t: torch.Tensor) -> torch.Tensor: |
| """ |
| Find index by t. |
| Return index of the same shape as t. |
| Index is -1 if t not found in timesteps. |
| """ |
| i, j = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True) |
| idx = torch.full_like(t, fill_value=-1, dtype=torch.int) |
| idx.view(-1)[i] = j.int() |
| return idx |