Spaces:
Sleeping
Sleeping
| import torch | |
| import numpy as np | |
| from torch.nn import Linear, Module | |
class VarianceSchedule(Module):
    """Precomputed per-step noise schedule stored as module buffers.

    All buffers have length ``num_steps + 1``: index 0 is a zero-padding
    entry (``beta = 0``) and indices ``1..num_steps`` hold the schedule.

    Buffers:
        betas: padded per-step beta values.
        alphas: ``1 - betas``.
        alpha_bars: cumulative product of ``alphas``.
        sigmas_flex: ``sqrt(betas)``.
        sigmas_inflex: ``sqrt((1 - alpha_bars[t-1]) / (1 - alpha_bars[t])
            * betas[t])`` for ``t >= 1`` and 0 at index 0.
    """

    SUPPORTED_MODES = ("linear", "quad", "cosine")

    def __init__(self, num_steps, beta_1, beta_T, mode="linear"):
        """Build the schedule and register its tensors as buffers.

        Args:
            num_steps: number of steps in the schedule.
            beta_1: first beta value (used by "linear" and "quad").
            beta_T: last beta value (used by "linear" and "quad").
            mode: one of "linear", "quad" or "cosine".

        Raises:
            ValueError: if ``mode`` is not a supported schedule name.
        """
        super().__init__()
        self.num_steps = num_steps
        self.beta_1 = beta_1
        self.beta_T = beta_T
        self.mode = mode

        betas = self._build_betas(num_steps, beta_1, beta_T, mode)
        # Pad with a leading zero so index t maps directly to step t.
        betas = torch.cat([torch.zeros([1]), betas], dim=0)

        alphas = 1 - betas
        # Cumulative product of alphas, accumulated in log space.
        alpha_bars = torch.cumsum(torch.log(alphas), dim=0).exp()

        sigmas_flex = torch.sqrt(betas)
        sigmas_inflex = torch.zeros_like(sigmas_flex)
        # Index 0 stays 0; every later entry pairs alpha_bars[t-1] with
        # alpha_bars[t] and betas[t].
        sigmas_inflex[1:] = (
            (1 - alpha_bars[:-1]) / (1 - alpha_bars[1:])
        ) * betas[1:]
        sigmas_inflex = torch.sqrt(sigmas_inflex)

        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alpha_bars", alpha_bars)
        self.register_buffer("sigmas_flex", sigmas_flex)
        self.register_buffer("sigmas_inflex", sigmas_inflex)

    @staticmethod
    def _build_betas(num_steps, beta_1, beta_T, mode):
        """Return the unpadded betas (length ``num_steps``) for ``mode``."""
        if mode == "linear":
            return torch.linspace(beta_1, beta_T, steps=num_steps)
        if mode == "quad":
            # Linear in sqrt(beta), then squared.
            return torch.linspace(beta_1 ** 0.5, beta_T ** 0.5, num_steps) ** 2
        if mode == "cosine":
            cosine_s = 8e-3
            timesteps = torch.arange(num_steps + 1) / num_steps + cosine_s
            alphas = timesteps / (1 + cosine_s) * np.pi / 2
            alphas = torch.cos(alphas).pow(2)
            betas = 1 - alphas[1:] / alphas[:-1]
            # Cap beta so that 1 - beta never reaches zero.
            return betas.clamp(max=0.999)
        raise ValueError(
            f"Unknown variance schedule mode {mode!r}; "
            f"expected one of {VarianceSchedule.SUPPORTED_MODES}"
        )

    def uniform_sample_t(self, batch_size):
        """Sample ``batch_size`` steps uniformly from ``1..num_steps``.

        Sampling is with replacement and uses NumPy's global random state.

        Returns:
            A Python list of ``batch_size`` ints.
        """
        ts = np.random.choice(np.arange(1, self.num_steps + 1), batch_size)
        return ts.tolist()

    def get_sigmas(self, t, flexibility):
        """Blend ``sigmas_flex[t]`` and ``sigmas_inflex[t]`` linearly.

        Args:
            t: index (or indices) into the schedule buffers.
            flexibility: weight in ``[0, 1]``; 1 selects ``sigmas_flex``
                and 0 selects ``sigmas_inflex``.

        Returns:
            ``sigmas_flex[t] * flexibility
            + sigmas_inflex[t] * (1 - flexibility)``.
        """
        assert 0 <= flexibility <= 1
        sigmas = self.sigmas_flex[t] * flexibility + self.sigmas_inflex[t] * (
            1 - flexibility
        )
        return sigmas