Spaces:
Running
on
Zero
Running
on
Zero
| from abc import ABC, abstractmethod | |
| from typing import Sequence, Union | |
| import torch | |
| from ..types import SamplingDirection | |
class Timesteps(ABC):
    """
    Timesteps base class.

    Holds the maximum timestep of a schedule. The Python type of that
    value distinguishes the schedule kind: int for discrete schedules,
    float for continuous ones.
    """

    def __init__(self, T: Union[int, float]):
        """
        Store the maximum timestep.

        Args:
            T: Maximum timestep (inclusive). Must be strictly positive.

        Raises:
            AssertionError: If T is not greater than zero.
        """
        assert T > 0
        self._T = T

    def T(self) -> Union[int, float]:
        """
        Maximum timestep inclusive.
        int if discrete, float if continuous.
        """
        return self._T

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous.

        Returns:
            True if the stored maximum timestep is a float, False otherwise.
        """
        # Check the stored value itself. `self.T` is a bound method here,
        # so `isinstance(self.T, float)` could never be True.
        return isinstance(self._T, float)
class SamplingTimesteps(Timesteps):
    """
    Sampling timesteps.
    It defines the discretization of sampling steps.
    """

    def __init__(
        self,
        T: Union[int, float],
        timesteps: torch.Tensor,
        direction: SamplingDirection,
    ):
        """
        Args:
            T: Maximum timestep, forwarded to the base class.
            timesteps: 1-D tensor holding the timestep of every sampling step.
            direction: Sampling direction; stored as-is on the instance.

        Raises:
            AssertionError: If timesteps is not 1-dimensional.
        """
        assert timesteps.ndim == 1
        super().__init__(T)
        self.timesteps = timesteps
        self.direction = direction

    def __len__(self) -> int:
        """
        Number of sampling steps.
        """
        return len(self.timesteps)

    def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor:
        """
        The timestep at the sampling step.
        Returns a scalar tensor if idx is int,
        or tensor of the same size if idx is a tensor.
        """
        return self.timesteps[idx]

    def index(self, t: torch.Tensor) -> torch.Tensor:
        """
        Find index by t.
        Return index of the same shape as t.
        Index is -1 if t not found in timesteps.

        Matching uses exact equality between t and the stored timesteps.
        NOTE(review): assumes the stored timesteps contain no duplicate
        values; confirm against the code that builds them.
        """
        # Compare every element of t (as a column) against all timesteps.
        # `rows` indexes the flattened t, `cols` the matching sampling step.
        rows, cols = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True)
        # Use a flat, contiguous buffer instead of `full_like(t).view(-1)`.
        # `full_like` keeps the strides of a dense non-contiguous t (e.g. a
        # transposed tensor), and `view(-1)` then raises on those strides.
        idx = torch.full((t.numel(),), -1, dtype=torch.int, device=t.device)
        idx[rows] = cols.int()
        return idx.reshape(t.shape)