| | from abc import ABC, abstractmethod |
| | from typing import Sequence, Union |
| | import torch |
| |
|
| | from ..types import SamplingDirection |
| |
|
| |
|
class Timesteps(ABC):
    """
    Abstract base for timestep schedules.

    Stores the maximum timestep ``T``. The Python type of ``T`` encodes the
    kind of schedule: an ``int`` means discrete, a ``float`` means continuous.
    """

    def __init__(self, T: Union[int, float]):
        # Only a strictly positive horizon is accepted (checked via assert).
        assert T > 0
        self._T = T

    @property
    def T(self) -> Union[int, float]:
        """
        Largest timestep, inclusive.

        Returns an ``int`` for a discrete schedule, a ``float`` for a
        continuous one.
        """
        return self._T

    def is_continuous(self) -> bool:
        """
        Return True when ``T`` is a float, i.e. the schedule is continuous.
        """
        # Go through the ``T`` property (not ``_T``) so subclasses overriding
        # ``T`` are respected.
        horizon = self.T
        continuous = isinstance(horizon, float)
        return continuous
| |
|
| |
|
class SamplingTimesteps(Timesteps):
    """
    Sampling timesteps.
    It defines the discretization of sampling steps: a 1-D tensor of the
    timesteps visited during sampling, plus the direction they are traversed.
    """

    def __init__(
        self,
        T: Union[int, float],
        timesteps: torch.Tensor,
        direction: SamplingDirection,
    ):
        """
        Args:
            T: maximum timestep (inclusive), forwarded to the base class.
            timesteps: 1-D tensor of sampling timesteps.
            direction: sampling direction; stored as-is.
        """
        # Only 1-D schedules are supported; ``index`` matches against a
        # single axis of timesteps.
        assert timesteps.ndim == 1
        super().__init__(T)
        self.timesteps = timesteps
        self.direction = direction

    def __len__(self) -> int:
        """
        Number of sampling steps.
        """
        return len(self.timesteps)

    def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor:
        """
        The timestep at the sampling step.
        Returns a scalar tensor if idx is int,
        or tensor of the same size if idx is a tensor.
        """
        return self.timesteps[idx]

    def index(self, t: torch.Tensor) -> torch.Tensor:
        """
        Find index by t.

        Returns a contiguous int32 tensor with the same shape as t (created on
        the device of the comparison with ``timesteps``). Index is -1 if the
        value is not found in timesteps. If a value occurs more than once in
        timesteps, the index of its first occurrence is returned.
        """
        if self.timesteps.numel() == 0:
            # Nothing can match; argmax below would fail on an empty axis.
            return torch.full(t.shape, -1, dtype=torch.int, device=t.device)

        # (t.numel(), len(timesteps)) match table. reshape flattens t in
        # logical (row-major) order, so reshaping the result back to t.shape
        # is correct regardless of t's memory layout (e.g. transposed views,
        # which previously broke ``view(-1)`` on a stride-preserving output).
        matches = t.reshape(-1, 1).eq(self.timesteps)
        # argmax returns the first maximal position, i.e. the first match.
        # This is deterministic, unlike scattering with duplicate indices,
        # and avoids the host sync that ``nonzero`` requires.
        first = matches.int().argmax(dim=1)
        idx = torch.where(matches.any(dim=1), first, torch.full_like(first, -1))
        return idx.int().reshape(t.shape)
| |
|