Spaces:
Sleeping
Sleeping
| from abc import ABC, abstractmethod | |
| import torch | |
| from torch import vmap | |
| from torch.func import jacrev | |
| import math | |
class Alpha(ABC):
    """
    Base class for a time-dependent coefficient alpha_t.

    Subclasses implement ``__call__``. The constructor evaluates the schedule at
    t = 0 and t = 1 and checks the endpoint values, so a subclass must set every
    attribute its ``__call__`` reads before invoking ``super().__init__()``.
    """

    def __init__(self, inverted: bool = False, atol: float = 1e-3, **kwargs):
        """
        Validate the endpoint values of the schedule.

        :param inverted: if False, require alpha_0 = 0 and alpha_1 = 1;
            if True, require alpha_0 = 1 and alpha_1 = 0
        :param atol: absolute tolerance used for the endpoint comparison
        :param kwargs: accepted and ignored
        :raises AssertionError: if an endpoint value is outside the tolerance
        """
        expected_at_0, expected_at_1 = (1.0, 0.0) if inverted else (0.0, 1.0)
        for time, expected in ((0.0, expected_at_0), (1.0, expected_at_1)):
            value = self(torch.full((1, 1, 1, 1), time))
            target = torch.full((1, 1, 1, 1), expected)
            # Raised explicitly (not via ``assert``) so the check still runs
            # when Python is started with -O.
            if not torch.allclose(value, target, atol=atol):
                raise AssertionError(
                    f"{type(self).__name__}: expected alpha_{time:g} = "
                    f"{expected:g} (atol={atol}), got {value}"
                )

    @abstractmethod
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates alpha_t. Should satisfy: self(0.0) = 0.0, self(1.0) = 1.0
        (or self(0.0) = 1.0, self(1.0) = 0.0 for an inverted schedule).
        :param t: time, shape (num_samples, 1, 1, 1)
        :return: alpha_t, shape (num_samples, 1, 1, 1)
        """

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates d/dt alpha_t by differentiating ``__call__`` with
        ``jacrev``, one sample at a time via ``vmap``.
        :param t: time, shape (num_samples, 1, 1, 1)
        :return: d/dt alpha_t, shape (num_samples, 1, 1, 1)
        """
        # vmap strips the leading axis; the extra singleton axis means each
        # per-sample input handed to ``self`` is still shaped (1, 1, 1, 1).
        per_sample_t = t.unsqueeze(1)
        derivative = vmap(jacrev(self))(per_sample_t)
        return derivative.view(-1, 1, 1, 1)
class Beta(ABC):
    """
    Base class for a time-dependent coefficient beta_t.

    Subclasses implement ``__call__``. The constructor evaluates the schedule at
    t = 0 and t = 1 and checks the endpoint values, so a subclass must set every
    attribute its ``__call__`` reads before invoking ``super().__init__()``.
    """

    def __init__(self, inverted: bool = False, atol: float = 1e-3, **kwargs):
        """
        Validate the endpoint values of the schedule.

        :param inverted: if False, require beta_0 = 1 and beta_1 = 0;
            if True, require beta_0 = 0 and beta_1 = 1
        :param atol: absolute tolerance used for the endpoint comparison
        :param kwargs: accepted and ignored
        :raises AssertionError: if an endpoint value is outside the tolerance
        """
        expected_at_0, expected_at_1 = (0.0, 1.0) if inverted else (1.0, 0.0)
        for time, expected in ((0.0, expected_at_0), (1.0, expected_at_1)):
            value = self(torch.full((1, 1, 1, 1), time))
            target = torch.full((1, 1, 1, 1), expected)
            # Raised explicitly (not via ``assert``) so the check still runs
            # when Python is started with -O.
            if not torch.allclose(value, target, atol=atol):
                raise AssertionError(
                    f"{type(self).__name__}: expected beta_{time:g} = "
                    f"{expected:g} (atol={atol}), got {value}"
                )

    @abstractmethod
    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates beta_t. Should satisfy: self(0.0) = 1.0, self(1.0) = 0.0
        (or self(0.0) = 0.0, self(1.0) = 1.0 for an inverted schedule).
        :param t: time, shape (num_samples, 1, 1, 1)
        :return: beta_t, shape (num_samples, 1, 1, 1)
        """

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates d/dt beta_t by differentiating ``__call__`` with
        ``jacrev``, one sample at a time via ``vmap``.
        :param t: time, shape (num_samples, 1, 1, 1)
        :return: d/dt beta_t, shape (num_samples, 1, 1, 1)
        """
        # vmap strips the leading axis; the extra singleton axis means each
        # per-sample input handed to ``self`` is still shaped (1, 1, 1, 1).
        per_sample_t = t.unsqueeze(1)
        derivative = vmap(jacrev(self))(per_sample_t)
        return derivative.view(-1, 1, 1, 1)
class LinearAlpha(Alpha):
    """
    Linear schedule: alpha_t = t, so d/dt alpha_t = 1 everywhere.
    """

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """
        Closed-form derivative of the linear schedule.
        :param t: time tensor
        :return: tensor of ones with the same shape, dtype and device as t
        """
        unit_slope = torch.ones_like(t)
        return unit_slope

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates alpha_t = t.
        :param t: time tensor
        :return: the input tensor itself, handed back unchanged
        """
        alpha_t = t
        return alpha_t
class LinearBeta(Beta):
    """
    Linear schedule: beta_t = 1 - t, so d/dt beta_t = -1 everywhere.
    """

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """
        Closed-form derivative of the linear schedule.
        :param t: time tensor
        :return: negated tensor of ones shaped like t
        """
        unit = torch.ones_like(t)
        return -unit

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluates beta_t = 1 - t.
        :param t: time tensor
        :return: beta_t, same shape as t
        """
        beta_t = 1 - t
        return beta_t
class CosineAlpha(Alpha):
    """
    Cosine alpha-bar schedule: alpha-bar(t) = f(t) / f(0) with
    f(t) = cos^2((t + s)/(1 + s) * pi/2).

    The schedule is validated as inverted: it starts at 1 for t = 0 and
    falls to (approximately) 0 at t = 1.
    """

    def __init__(self, s: float = 0.008):
        """
        :param s: small offset used for stability
        """
        self.s = s
        # f(0) is cached before the base-class constructor runs, because that
        # constructor evaluates the schedule and __call__ divides by it.
        origin = torch.tensor(0.0)
        self._f_0 = self._calculate_f(origin)
        super().__init__(inverted=True)

    def _calculate_f(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the unnormalised curve f(t) = cos^2((t + s)/(1 + s) * pi/2).
        :param t: current timestep
        :return: f(t)
        """
        quarter_turn = math.pi / 2
        shifted = (t + self.s) / (1 + self.s)
        return torch.cos(shifted * quarter_turn).pow(2)

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate alpha-bar(t) = f(t) / f(0).
        :param t: time tensor
        :return: alpha-bar(t), same shape as t
        """
        return self._calculate_f(t) / self._f_0

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """
        Closed-form derivative of alpha-bar via the chain rule:
        d/dt cos^2(a) = -2 cos(a) sin(a) * da/dt, divided by f(0).
        :param t: time tensor
        :return: d/dt alpha-bar(t), same shape as t
        """
        phase = (t + self.s) / (1 + self.s) * math.pi / 2
        phase_rate = math.pi / (2 * (1 + self.s))
        slope = -2 * torch.cos(phase)
        slope = slope * torch.sin(phase)
        slope = slope * phase_rate
        return slope / self._f_0
class CosineBeta(Beta):
    """
    Cosine Beta noise schedule, derived from an alpha-bar schedule as
    beta_t = 1 - alpha-bar(t) / alpha-bar(t - step), with t - step clamped at 0.

    The schedule is validated as inverted (beta_0 = 0, beta_1 = 1).
    """

    def __init__(self, alpha: CosineAlpha, dt: float = 1e-3):
        """
        :param alpha: cosine alpha noise schedule.
        :param dt: small timestep to approximate beta_t via finite differences
        """
        self.alpha = alpha
        # Stored under a private name: assigning it to ``self.dt`` would shadow
        # the ``dt()`` method on the instance, making ``beta.dt(t)`` fail with
        # "'float' object is not callable".
        self._step = dt
        super().__init__(inverted=True)

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        beta_t = 1 - alpha-bar(t) / alpha-bar(t - step)
        :param t: time tensor
        :return: beta_t, same shape as t
        """
        # Clamp so the previous time never goes below the start of the schedule.
        t_prev = torch.clamp(t - self._step, min=0.0)
        alpha_t = self.alpha(t)
        alpha_prev = self.alpha(t_prev)
        return 1.0 - alpha_t / alpha_prev

    def dt(self, t: torch.Tensor) -> torch.Tensor:
        """
        Approximates d/dt beta_t with a finite difference of half-width
        ``step`` around t, clamped to the interval [0, 1].
        :param t: time tensor, expected to lie in [0, 1]
        :return: approximate d/dt beta_t, same shape as t
        """
        t_prev = torch.clamp(t - self._step, min=0.0)
        t_next = torch.clamp(t + self._step, max=1.0)
        # Divide by the actual distance between the two sample points: near
        # t = 0 or t = 1 the clamp shrinks the interval, and a fixed 2 * step
        # denominator would underestimate the slope there.
        return (self(t_next) - self(t_prev)) / (t_next - t_prev)