| from abc import ABC |
| from typing import Sequence, Union |
| import torch |
| from torch.distributions import LogisticNormal |
|
|
class Timesteps(ABC):
    """
    Timesteps base class.

    Holds the maximum timestep ``T``. The Python type of ``T`` selects the
    kind of schedule: a ``float`` marks it as continuous, anything else
    (e.g. ``int``) as discrete.
    """

    def __init__(self, T: Union[int, float]):
        """
        Args:
            T: Maximum timestep (inclusive). Must be strictly positive.

        Raises:
            ValueError: If ``T`` is not strictly positive (NaN included).
        """
        # Explicit check rather than ``assert`` so validation is not stripped
        # under ``python -O``. ``not T > 0`` (instead of ``T <= 0``) also
        # rejects NaN, matching what the previous assert did.
        if not T > 0:
            raise ValueError(f"T must be strictly positive, got {T!r}")
        self._T = T

    @property
    def T(self) -> Union[int, float]:
        """
        Maximum timestep inclusive.
        int if discrete, float if continuous.
        """
        return self._T

    def is_continuous(self) -> bool:
        """
        Whether the schedule is continuous, i.e. ``T`` is a Python ``float``.
        """
        return isinstance(self.T, float)
|
|
class LogitNormalTrainingTimesteps(Timesteps):
    """
    Logit-Normal sampling of timesteps in [0, T].

    A draw ``x ~ Normal(loc, scale)`` is mapped through a sigmoid into (0, 1)
    and scaled by ``T``. For a discrete schedule (non-float ``T``) the result
    is rounded to the nearest integer timestep.
    """

    def __init__(self, T: Union[int, float], loc: float, scale: float):
        """
        Args:
            T: Maximum timestep (inclusive), forwarded to ``Timesteps``.
            loc: Mean of the underlying Normal, in logit space.
            scale: Standard deviation of the underlying Normal, in logit space.
        """
        super().__init__(T)
        # With scalar parameters torch's LogisticNormal produces 2-component
        # samples (points on the 1-simplex); component 0 is the sigmoid of a
        # Normal(loc, scale) draw, which is what ``sample`` keeps.
        self.dist = LogisticNormal(loc, scale)

    def sample(
        self,
        size: Union[int, Sequence[int]],
        device: Union[str, torch.device] = "cpu",
    ) -> torch.Tensor:
        """
        Draw a tensor of timesteps.

        Args:
            size: Shape of the returned tensor; a bare ``int`` n means ``(n,)``.
            device: Device the result is moved to. Samples are drawn on the
                distribution's own device first, then transferred.

        Returns:
            A float tensor with values in [0, T] if the schedule is continuous,
            otherwise an ``int32`` tensor of rounded timesteps in [0, T].
        """
        if isinstance(size, int):
            # torch.Size(n) rejects a bare int; treat it as a 1-D shape.
            size = (size,)
        # Out-of-place multiply so the result is a fresh dense tensor on every
        # device, not a strided view into the discarded second component.
        t = self.dist.sample(size)[..., 0].to(device) * self.T
        return t if self.is_continuous() else t.round().int()
|
|