import numpy as np
import torch

#----------------------------------------------------------------------------
# Preconditioning corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".

class VEPrecond(torch.nn.Module):
    def __init__(self,
        model,                          # Underlying denoising network.
        label_dim   = 0,                # Number of class labels, 0 = unconditional.
        use_fp16    = False,            # Execute the underlying model at FP16 precision?
        sigma_min   = 0.02,             # Minimum supported noise level.
        sigma_max   = 100,              # Maximum supported noise level.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        x = x.to(torch.float32)
        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = sigma
        c_in = 1
        c_noise = (0.5 * sigma).log()

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

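# Minimal usage sketch (illustrative only): wrapping an arbitrary network with
# VEPrecond and running one denoising step. `_ToyNet` is a hypothetical
# stand-in; the real model is assumed to accept (x, noise_labels, **kwargs),
# matching the forward() calls above.
def _ve_precond_example():
    class _ToyNet(torch.nn.Module):
        def forward(self, x, noise_labels, **kwargs):
            return torch.zeros_like(x)  # placeholder prediction
    net = VEPrecond(_ToyNet())
    x = torch.randn(2, 3, 32, 32)       # batch of noisy images
    sigma = torch.full([2], 10.0)       # one noise level per sample
    D_x = net(x, sigma)                 # denoised estimate, float32
    assert D_x.shape == x.shape
    return D_x
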
#----------------------------------------------------------------------------
# Preconditioning corresponding to improved DDPM (iDDPM) formulation from
# the paper "Improved Denoising Diffusion Probabilistic Models".

class iDDPMPrecond(torch.nn.Module):
    def __init__(self,
        model,                          # Underlying denoising network.
        label_dim   = 0,                # Number of class labels, 0 = unconditional.
        use_fp16    = False,            # Execute the underlying model at FP16 precision?
        C_1         = 0.001,            # Timestep adjustment at low noise levels.
        C_2         = 0.008,            # Timestep adjustment at high noise levels.
        M           = 1000,             # Original number of timesteps in the DDPM formulation.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.C_1 = C_1
        self.C_2 = C_2
        self.M = M
        self.model = model

        # Tabulate the discrete noise levels u[j] for timesteps j = 0, ..., M
        # implied by alpha_bar; u is decreasing in j, with u[M] = 0.
        u = torch.zeros(M + 1)
        for j in range(M, 0, -1): # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        self.register_buffer('u', u)
        self.sigma_min = float(u[M - 1])
        self.sigma_max = float(u[0])

    def forward(self, x, sigma, class_labels=None, lamb=None, force_fp32=False, **model_kwargs):
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        x = x.to(torch.float32)
        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)

        # Optional extra conditioning `lamb` is passed through to the model
        # as a second positional argument when given.
        if lamb is not None:
            F_x = self.model((c_in * x).to(dtype), lamb, c_noise.flatten(), **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def alpha_bar(self, j):
        j = torch.as_tensor(j)
        return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2

    def round_sigma(self, sigma, return_index=False):
        sigma = torch.as_tensor(sigma)
        index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
        result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
        return result.reshape(sigma.shape).to(sigma.device)

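# Minimal usage sketch (illustrative only). `_ToyNet` is a hypothetical
# stand-in, assumed to take (x, noise_labels, **kwargs); pass `lamb` only if
# the wrapped model accepts it as a second positional argument.
def _iddpm_precond_example():
    class _ToyNet(torch.nn.Module):
        def forward(self, x, noise_labels, **kwargs):
            return torch.zeros_like(x)  # placeholder prediction
    net = iDDPMPrecond(_ToyNet())
    x = torch.randn(2, 3, 32, 32)
    sigma = net.round_sigma(torch.full([2], 5.0))  # snap to the discrete u table
    return net(x, sigma)
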
#----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).

class EDMPrecond(torch.nn.Module):
    def __init__(self,
        model,                          # Underlying denoising network.
        label_dim   = 0,                # Number of class labels, 0 = unconditional.
        use_fp16    = False,            # Execute the underlying model at FP16 precision?
        sigma_min   = 0,                # Minimum supported noise level.
        sigma_max   = float('inf'),     # Maximum supported noise level.
        sigma_data  = 0.5,              # Expected standard deviation of the training data.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        x = x.to(torch.float32)
        sigma = sigma.to(torch.float32).reshape(-1, 1, 1, 1)
        if class_labels is not None:
            if self.label_dim == 0:
                class_labels = None
            else:
                class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        c_in = c_in.to(x.device)
        c_noise = sigma.log() / 4

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), c_latent=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        return torch.as_tensor(sigma)

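# Minimal usage sketch (illustrative only). `_ToyNet` is a hypothetical
# stand-in. Note that when class_labels are given, this variant forwards them
# to the model as `c_latent`, so the wrapped network is assumed to accept
# that keyword.
def _edm_precond_example():
    class _ToyNet(torch.nn.Module):
        def forward(self, x, noise_labels, **kwargs):
            return torch.zeros_like(x)  # placeholder prediction
    net = EDMPrecond(_ToyNet(), sigma_data=0.5)
    x = torch.randn(2, 3, 32, 32)
    sigma = torch.full([2], 1.0)
    return net(x, sigma)  # D_x = c_skip * x when the model outputs zeros
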
#----------------------------------------------------------------------------