from abc import ABC, abstractmethod
from functools import partial

import numpy as np
import torch

from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero
|
| |
|
| |
|
def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    """Pick ``num_substeps`` integer step indices spread over ``[0, max_step)``.

    The indices are sampled evenly downward from ``max_step - 1``, with the
    open end at 0 (``endpoint=False``). They are truncated to ``int`` and
    returned in ascending order, so the last entry is always ``max_step - 1``
    (when ``num_substeps > 0``). Indices can repeat if ``num_substeps`` is
    large relative to ``max_step``.

    Example: ``(4, 10)`` -> ``[2, 4, 6, 9]``.
    """
    # Evenly spaced values from the top step downward, excluding 0 itself.
    descending = np.linspace(max_step - 1, 0, num_substeps, endpoint=False)
    indices = descending.astype(int)
    # Reverse so callers receive the indices smallest-first.
    return np.flip(indices)
|
| |
|
| |
|
class Discretization(ABC):
    """Abstract base class for noise-level (sigma) schedules.

    Subclasses implement :meth:`get_sigmas`. Calling an instance fetches those
    sigmas and optionally post-processes them (append a zero, reverse).

    Inheriting from ``ABC`` makes ``@abstractmethod`` actually enforced:
    instantiating this base, or a subclass that forgets ``get_sigmas``,
    raises ``TypeError`` instead of silently producing ``None`` sigmas.
    """

    def __call__(self, n, do_append_zero=True, device="cpu", flip=False):
        """Return the sigma schedule for ``n`` steps.

        Args:
            n: Number of sigmas requested from :meth:`get_sigmas`.
            do_append_zero: If True, pass the sigmas through ``append_zero``
                (presumably appends a terminal 0 — see ``...util.append_zero``).
            device: Device forwarded to :meth:`get_sigmas`.
            flip: If True, reverse the result along dim 0. This happens after
                any zero has been appended.

        Returns:
            The (optionally extended and reversed) sigma tensor.
        """
        sigmas = self.get_sigmas(n, device=device)
        if do_append_zero:
            sigmas = append_zero(sigmas)
        if flip:
            sigmas = torch.flip(sigmas, (0,))
        return sigmas

    @abstractmethod
    def get_sigmas(self, n, device):
        """Return the noise levels for ``n`` steps as a tensor on ``device``.

        Must be overridden by subclasses.
        """
        raise NotImplementedError
|
| |
|
| |
|
class EDMDiscretization(Discretization):
    """Rho-spaced sigma schedule.

    This is the schedule from Karras et al., 2022 ("EDM"). Sigmas are
    interpolated linearly in ``sigma ** (1 / rho)`` space, going from
    ``sigma_max`` down to ``sigma_min``, and then raised back to the power
    ``rho``.
    """

    def __init__(self, sigma_min=0.02, sigma_max=80.0, rho=7.0):
        self.sigma_min, self.sigma_max, self.rho = sigma_min, sigma_max, rho

    def get_sigmas(self, n, device="cpu"):
        """Return ``n`` sigmas, starting at ``sigma_max`` and ending at ``sigma_min``.

        For ``n == 1`` only ``sigma_max`` is returned.
        """
        exponent = 1 / self.rho
        start = self.sigma_max ** exponent
        end = self.sigma_min ** exponent
        # Interpolation weights 0 -> 1 across the n steps.
        t = torch.linspace(0, 1, n, device=device)
        return (start + t * (end - start)) ** self.rho
|
| |
|
| |
|
class LegacyDDPMDiscretization(Discretization):
    """Sigma schedule derived from a discrete DDPM "linear" beta schedule.

    The betas come from ``make_beta_schedule("linear", ...)``. For each
    timestep, ``sigma = sqrt((1 - alpha_bar) / alpha_bar)``, where
    ``alpha_bar`` is the cumulative product of ``1 - beta``.
    """

    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule(
            "linear", num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        alphas = 1.0 - betas
        # alpha_bar_t: running product of alphas along the timestep axis.
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        # Kept as a public attribute for backward compatibility; get_sigmas
        # builds its tensor directly so it can honor the requested device.
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

    def get_sigmas(self, n, device="cpu"):
        """Return ``n`` sigmas as a float32 tensor on ``device``.

        If ``n`` is less than ``num_timesteps``, a roughly evenly spaced subset
        of the timesteps is used. If it is equal, all timesteps are used.
        The result is reversed relative to timestep order. Assuming betas lie
        in (0, 1), alpha_bar decreases with t, so the sigmas then run from
        largest to smallest.

        Raises:
            ValueError: if ``n`` exceeds ``num_timesteps``.
        """
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(
                f"Requested {n} sigmas, but this discretization only has "
                f"{self.num_timesteps} timesteps."
            )

        # Cast to float32 before taking the square root (matches prior numerics).
        sigmas = (
            torch.tensor(
                (1 - alphas_cumprod) / alphas_cumprod,
                dtype=torch.float32,
                device=device,
            )
            ** 0.5
        )
        return torch.flip(sigmas, (0,))
|
| |
|