import numpy as np
import torch

#----------------------------------------------------------------------------
# Preconditioning corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

class VEPrecond(torch.nn.Module):
    def __init__(self,
        model,                          # Underlying denoiser network to wrap.
        label_dim   = 0,                # Number of class labels, 0 = unconditional.
        use_fp16    = False,            # Execute the underlying model at FP16 precision?
        sigma_min   = 0.02,             # Minimum supported noise level.
        sigma_max   = 100,              # Maximum supported noise level.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        x = x.to(torch.float32)
        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = sigma
        c_in = 1
        c_noise = (0.5 * sigma).log()

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)
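# Minimal usage sketch (illustration only, not part of the original file):
# wraps a hypothetical stand-in network in VEPrecond and runs one denoising
# call. `_ToyModel` is a placeholder assumption for the real score network,
# which is assumed to take (x, noise_labels, **kwargs).
def _ve_precond_demo():
    class _ToyModel(torch.nn.Module):
        def forward(self, x, noise_labels, class_labels=None, **kwargs):
            return x  # identity stand-in for the actual network
    net = VEPrecond(_ToyModel())
    x = torch.randn(2, 3, 32, 32)       # batch of noisy images
    sigma = torch.full([2], 10.0)       # per-sample noise level
    D_x = net(x, sigma)                 # denoised estimate, same shape as x
    assert D_x.shape == x.shape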
#----------------------------------------------------------------------------
# Preconditioning corresponding to improved DDPM (iDDPM) formulation from
# the paper "Improved Denoising Diffusion Probabilistic Models".

class iDDPMPrecond(torch.nn.Module):
    def __init__(self,
        model,                          # Underlying denoiser network to wrap.
        label_dim   = 0,                # Number of class labels, 0 = unconditional.
        use_fp16    = False,            # Execute the underlying model at FP16 precision?
        C_1         = 0.001,            # Timestep adjustment at low noise levels.
        C_2         = 0.008,            # Timestep adjustment at high noise levels.
        M           = 1000,             # Original number of timesteps in the DDPM formulation.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.C_1 = C_1
        self.C_2 = C_2
        self.M = M
        self.model = model

        u = torch.zeros(M + 1)
        for j in range(M, 0, -1): # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        self.register_buffer('u', u)
        self.sigma_min = float(u[M - 1])
        self.sigma_max = float(u[0])

    def forward(self, x, sigma, class_labels=None, lamb=None, force_fp32=False, **model_kwargs):
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        x = x.to(torch.float32)
        class_labels = None if self.label_dim == 0 else torch.zeros([1, self.label_dim], device=x.device) if class_labels is None else class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
        # Note: the class_labels conditioning of the original EDM code is
        # replaced here by the optional `lamb` argument, which is passed
        # through to the model when given.
        if lamb is not None:
            F_x = self.model((c_in * x).to(dtype), lamb, c_noise.flatten(), **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def alpha_bar(self, j):
        j = torch.as_tensor(j)
        return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2

    def round_sigma(self, sigma, return_index=False):
        sigma = torch.as_tensor(sigma)
        index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
        result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
        return result.reshape(sigma.shape).to(sigma.device)
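# Minimal usage sketch (illustration only, not part of the original file):
# snaps a continuous sigma onto the discrete iDDPM noise-level grid `u` via
# round_sigma, then runs one denoising call. `_ToyModel` is a hypothetical
# placeholder for the real network.
def _iddpm_precond_demo():
    class _ToyModel(torch.nn.Module):
        def forward(self, x, noise_labels, **kwargs):
            return x  # identity stand-in for the actual network
    net = iDDPMPrecond(_ToyModel())
    x = torch.randn(2, 3, 32, 32)
    sigma = net.round_sigma(torch.full([2], 10.0))  # nearest grid value
    D_x = net(x, sigma)
    assert D_x.shape == x.shape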
#----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).

class EDMPrecond(torch.nn.Module):
    def __init__(self,
        model,                          # Underlying denoiser network to wrap.
        label_dim   = 0,                # Number of class labels, 0 = unconditional.
        use_fp16    = False,            # Execute the underlying model at FP16 precision?
        sigma_min   = 0,                # Minimum supported noise level.
        sigma_max   = float('inf'),     # Maximum supported noise level.
        sigma_data  = 0.5,              # Expected standard deviation of the training data.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if class_labels is not None:
            if self.label_dim == 0:
                class_labels = None
            else:
                class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_in = c_in.to(x.device)
        c_noise = sigma.log() / 4

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), c_latent=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)
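# Minimal usage sketch (illustration only, not part of the original file):
# runs one unconditional EDM denoising call, where the output is a
# sigma-dependent blend of the input (c_skip) and the network prediction
# (c_out). `_ToyModel` is a hypothetical placeholder for the real network.
def _edm_precond_demo():
    class _ToyModel(torch.nn.Module):
        def forward(self, x, noise_labels, **kwargs):
            return x  # identity stand-in for the actual network
    net = EDMPrecond(_ToyModel(), sigma_data=0.5)
    x = torch.randn(2, 3, 32, 32)
    sigma = torch.full([2], 1.0)
    D_x = net(x, sigma)
    assert D_x.shape == x.shape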
#----------------------------------------------------------------------------