Spaces:
Sleeping
Sleeping
| import abc | |
| import torch | |
| import torch.nn as nn | |
| # Flags required to enable jit fusion kernels | |
| torch._C._jit_set_profiling_mode(False) | |
| torch._C._jit_set_profiling_executor(False) | |
| torch._C._jit_override_can_fuse_on_cpu(True) | |
| torch._C._jit_override_can_fuse_on_gpu(True) | |
| def _sample_categorical(categorical_probs): | |
| gumbel_norm = ( | |
| 1e-10 | |
| - (torch.rand_like(categorical_probs) + 1e-10).log()) | |
| return (categorical_probs / gumbel_norm).argmax(dim=-1) | |
| def _unsqueeze(x, reference): | |
| return x.view( | |
| * x.shape, | |
| * ((1,) * (len(reference.shape) - len(x.shape)))) | |
| def _sample_t(n, device, antithetic_sampling=True, sampling_eps=1e-3): | |
| _eps_t = torch.rand(n, device=device) | |
| if antithetic_sampling: | |
| offset = torch.arange(n, device=device) / n | |
| _eps_t = (_eps_t / n + offset) % 1 | |
| t = (1 - sampling_eps) * _eps_t + sampling_eps | |
| return t | |
| def q_xt( x, move_chance, mask_index): | |
| """Computes the noisy sample xt. | |
| Args: | |
| x: int torch.Tensor with shape (batch_size, | |
| diffusion_model_input_length), input. | |
| move_chance: float torch.Tensor with shape (batch_size, 1). | |
| """ | |
| move_indices = torch.rand( | |
| * x.shape, device=x.device) < move_chance | |
| xt = torch.where(move_indices, mask_index, x) | |
| return xt | |
| def get_noise(config, dtype=torch.float32): | |
| if config.noise.type == 'geometric': | |
| return GeometricNoise(config.noise.sigma_min, | |
| config.noise.sigma_max) | |
| elif config.noise.type == 'loglinear': | |
| return LogLinearNoise() | |
| elif config.noise.type == 'cosine': | |
| return CosineNoise() | |
| elif config.noise.type == 'cosinesqr': | |
| return CosineSqrNoise() | |
| elif config.noise.type == 'linear': | |
| return Linear(config.noise.sigma_min, | |
| config.noise.sigma_max, | |
| dtype) | |
| else: | |
| raise ValueError(f'{config.noise.type} is not a valid noise') | |
| def binary_discretization(z): | |
| z_hard = torch.sign(z) | |
| z_soft = z / torch.norm(z, dim=-1, keepdim=True) | |
| return z_soft + (z_hard - z_soft).detach() | |
| class Noise(abc.ABC, nn.Module): | |
| """ | |
| Baseline forward method to get the total + rate of noise at a timestep | |
| """ | |
| def forward(self, t): | |
| # Assume time goes from 0 to 1 | |
| return self.total_noise(t), self.rate_noise(t) | |
| def rate_noise(self, t): | |
| """ | |
| Rate of change of noise ie g(t) | |
| """ | |
| pass | |
| def total_noise(self, t): | |
| """ | |
| Total noise ie \int_0^t g(t) dt + g(0) | |
| """ | |
| pass | |
| class CosineNoise(Noise): | |
| def __init__(self, eps=1e-3): | |
| super().__init__() | |
| self.eps = eps | |
| def rate_noise(self, t): | |
| cos = (1 - self.eps) * torch.cos(t * torch.pi / 2) | |
| sin = (1 - self.eps) * torch.sin(t * torch.pi / 2) | |
| scale = torch.pi / 2 | |
| return scale * sin / (cos + self.eps) | |
| def total_noise(self, t): | |
| cos = torch.cos(t * torch.pi / 2) | |
| return - torch.log(self.eps + (1 - self.eps) * cos) | |
| class CosineSqrNoise(Noise): | |
| def __init__(self, eps=1e-3): | |
| super().__init__() | |
| self.eps = eps | |
| def rate_noise(self, t): | |
| cos = (1 - self.eps) * ( | |
| torch.cos(t * torch.pi / 2) ** 2) | |
| sin = (1 - self.eps) * torch.sin(t * torch.pi) | |
| scale = torch.pi / 2 | |
| return scale * sin / (cos + self.eps) | |
| def total_noise(self, t): | |
| cos = torch.cos(t * torch.pi / 2) ** 2 | |
| return - torch.log(self.eps + (1 - self.eps) * cos) | |
| class Linear(Noise): | |
| def __init__(self, sigma_min=0, sigma_max=10, dtype=torch.float32): | |
| super().__init__() | |
| self.sigma_min = torch.tensor(sigma_min, dtype=dtype) | |
| self.sigma_max = torch.tensor(sigma_max, dtype=dtype) | |
| def rate_noise(self, t): | |
| return self.sigma_max - self.sigma_min | |
| def total_noise(self, t): | |
| return self.sigma_min + t * (self.sigma_max - self.sigma_min) | |
| def importance_sampling_transformation(self, t): | |
| f_T = torch.log1p(- torch.exp(- self.sigma_max)) | |
| f_0 = torch.log1p(- torch.exp(- self.sigma_min)) | |
| sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0)) | |
| return (sigma_t - self.sigma_min) / ( | |
| self.sigma_max - self.sigma_min) | |
| class GeometricNoise(Noise): | |
| def __init__(self, sigma_min=1e-3, sigma_max=1): | |
| super().__init__() | |
| self.sigmas = 1.0 * torch.tensor([sigma_min, sigma_max]) | |
| def rate_noise(self, t): | |
| return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t * ( | |
| self.sigmas[1].log() - self.sigmas[0].log()) | |
| def total_noise(self, t): | |
| return self.sigmas[0] ** (1 - t) * self.sigmas[1] ** t | |
| class LogLinearNoise(Noise): | |
| """Log Linear noise schedule. | |
| Built such that 1 - 1/e^(n(t)) interpolates between 0 and 1. | |
| """ | |
| def __init__(self, eps=1e-3): | |
| super().__init__() | |
| self.eps = eps | |
| self.sigma_max = self.total_noise(torch.tensor(1.0)) | |
| self.sigma_min = self.eps + self.total_noise(torch.tensor(0.0)) | |
| def rate_noise(self, t): | |
| return (1 - self.eps) / (1 - (1 - self.eps) * t) | |
| def total_noise(self, t): | |
| return -torch.log1p(-(1 - self.eps) * t) | |
| def importance_sampling_transformation(self, t): | |
| f_T = torch.log1p(- torch.exp(- self.sigma_max)) | |
| f_0 = torch.log1p(- torch.exp(- self.sigma_min)) | |
| sigma_t = - torch.log1p(- torch.exp(t * f_T + (1 - t) * f_0)) | |
| t = - torch.expm1(- sigma_t) / (1 - self.eps) | |
| return t | |