daidedou
forgot a few things lol
e321b92
raw
history blame
7.08 kB
import numpy as np
import torch
#----------------------------------------------------------------------------
# Preconditioning corresponding to the variance exploding (VE) formulation
# from the paper "Score-Based Generative Modeling through Stochastic
# Differential Equations".
class VEPrecond(torch.nn.Module):
    """Variance-exploding (VE) preconditioning around a denoiser network ``model``.

    The denoised estimate is ``D(x; sigma) = x + sigma * F(x, log(sigma / 2))``
    where ``F`` is the wrapped ``model`` (``c_skip = 1``, ``c_in = 1``,
    ``c_out = sigma``, ``c_noise = log(0.5 * sigma)``).
    """

    def __init__(self,
        model,                  # Network F; called as model(x, c_noise, [class_labels=...], **model_kwargs).
        label_dim   = 0,        # Number of class labels, 0 = unconditional.
        use_fp16    = False,    # Execute the underlying model at FP16 precision?
        sigma_min   = 0.02,     # Minimum supported noise level.
        sigma_max   = 100,      # Maximum supported noise level.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        # sigma_min / sigma_max are only stored here; nothing in this class
        # clamps sigma to them (presumably read by the sampler — confirm).
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Return the VE-preconditioned denoised estimate ``D_x`` (float32).

        Args:
            x: Noisy input, cast to float32. Sigma is broadcast as
                ``(-1, 1, 1, 1)``, so ``x`` is assumed to be 4-D
                (presumably ``(N, C, H, W)`` — TODO confirm).
            sigma: Noise level(s), a tensor or a Python/NumPy scalar. It is
                converted to float32 on ``x.device``.
            class_labels: Optional labels reshaped to ``(-1, label_dim)``.
                Ignored when ``label_dim == 0``. Replaced by a single zero row
                when omitted and ``label_dim > 0``.
            force_fp32: Disable FP16 execution for this call.
            **model_kwargs: Forwarded unchanged to ``model``.

        Returns:
            ``x + sigma * F_x`` in float32.
        """
        x = x.to(torch.float32)
        # Put sigma on x's device: every coefficient derived from it is
        # multiplied with x / F_x below, and a non-scalar tensor on another
        # device would raise a device-mismatch error. as_tensor also lets
        # callers pass plain Python scalars.
        sigma = torch.as_tensor(sigma, dtype=torch.float32, device=x.device).reshape(-1, 1, 1, 1)

        if self.label_dim == 0:
            class_labels = None
        elif class_labels is None:
            class_labels = torch.zeros([1, self.label_dim], device=x.device)
        else:
            class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)

        # FP16 is only used on CUDA and only when not explicitly disabled.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = sigma
        c_in = 1
        c_noise = (0.5 * sigma).log()

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        # Sanity check that the model honoured the requested precision.
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor. The VE formulation has no discrete sigma grid."""
        return torch.as_tensor(sigma)
#----------------------------------------------------------------------------
# Preconditioning corresponding to improved DDPM (iDDPM) formulation from
# the paper "Improved Denoising Diffusion Probabilistic Models".
class iDDPMPrecond(torch.nn.Module):
    """Improved-DDPM preconditioning around a denoiser network ``model``.

    A discrete table ``u[0..M]`` of noise levels is built in ``__init__``.
    ``forward`` maps a continuous ``sigma`` to the index of its nearest ``u``
    entry and feeds ``M - 1 - index`` to the model as the noise conditioning.
    The estimate is ``D = x - sigma * F(x / sqrt(sigma^2 + 1), c_noise)``.
    """

    def __init__(self,
        model,              # Network F; called as model(x, [lamb,] c_noise, **model_kwargs).
        label_dim   = 0,    # Number of class labels, 0 = unconditional (labels are not forwarded, see forward()).
        use_fp16    = False,  # Execute the underlying model at FP16 precision?
        C_1         = 0.001,  # Timestep adjustment at low noise levels.
        C_2         = 0.008,  # Timestep adjustment at high noise levels.
        M           = 1000,   # Original number of timesteps in the DDPM formulation.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        self.C_1 = C_1
        self.C_2 = C_2
        self.M = M
        self.model = model

        # Build the sigma table backwards from u[M] = 0. Each step divides by
        # the alpha_bar ratio, clipped from below at C_1, so the table grows
        # towards j = 0. This is a recurrence, so it stays a Python loop.
        u = torch.zeros(M + 1)
        for j in range(M, 0, -1):  # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        # As a buffer, u follows the module across .to(device) calls.
        self.register_buffer('u', u)
        self.sigma_min = float(u[M - 1])
        self.sigma_max = float(u[0])

    def forward(self, x, sigma, class_labels=None, lamb=None, force_fp32=False, **model_kwargs):
        """Return the iDDPM-preconditioned denoised estimate ``D_x`` (float32).

        Args:
            x: Noisy input, cast to float32. Sigma is broadcast as
                ``(-1, 1, 1, 1)``, so ``x`` is assumed to be 4-D
                (presumably ``(N, C, H, W)`` — TODO confirm).
            sigma: Noise level(s), a tensor or a Python/NumPy scalar. It is
                converted to float32 on ``x.device``.
            class_labels: Accepted but NOT forwarded to the model in this
                version (the label-conditioned call was disabled).
                NOTE(review): presumably kept for signature compatibility with
                the other preconditioners — confirm.
            lamb: Optional extra model input. When not None it is passed
                positionally between the scaled input and ``c_noise``. Its
                meaning depends on the wrapped model (TODO document).
            force_fp32: Disable FP16 execution for this call.
            **model_kwargs: Forwarded unchanged to ``model``.

        Returns:
            ``x - sigma * F_x`` in float32.
        """
        x = x.to(torch.float32)
        # Put sigma on x's device. The coefficients (and c_noise, via
        # round_sigma) inherit sigma's device and are combined with x / F_x
        # below. as_tensor also accepts plain Python scalars.
        sigma = torch.as_tensor(sigma, dtype=torch.float32, device=x.device).reshape(-1, 1, 1, 1)
        # FP16 is only used on CUDA and only when not explicitly disabled.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = 1
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + 1).sqrt()
        # Conditioning is the (reversed) index of the nearest tabulated sigma.
        c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)

        x_in = (c_in * x).to(dtype)
        if lamb is not None:
            F_x = self.model(x_in, lamb, c_noise.flatten(), **model_kwargs)
        else:
            F_x = self.model(x_in, c_noise.flatten(), **model_kwargs)
        # Sanity check that the model honoured the requested precision.
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def alpha_bar(self, j):
        """Return ``sin(pi/2 * j / M / (C_2 + 1)) ** 2`` as a tensor for timestep(s) ``j``."""
        j = torch.as_tensor(j)
        return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2

    def round_sigma(self, sigma, return_index=False):
        """Snap ``sigma`` to the nearest entry of the ``u`` table.

        Args:
            sigma: Tensor or scalar noise level(s), any shape.
            return_index: If True, return the int64 index into ``u`` instead of
                the snapped value.

        Returns:
            A tensor with ``sigma``'s shape, on ``sigma``'s device. Snapped
            values are cast back to ``sigma``'s dtype.
        """
        sigma = torch.as_tensor(sigma)
        index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
        result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
        return result.reshape(sigma.shape).to(sigma.device)
#----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).
class EDMPrecond(torch.nn.Module):
    """EDM preconditioning around a denoiser network ``model``.

    Computes ``D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)`` with
    ``c_skip = sd^2 / (sigma^2 + sd^2)``, ``c_out = sigma * sd / sqrt(sigma^2 + sd^2)``,
    ``c_in = 1 / sqrt(sd^2 + sigma^2)`` and ``c_noise = ln(sigma) / 4``,
    where ``sd = sigma_data``.
    """

    def __init__(self,
        model,                          # Network F; called as model(x, c_noise, [c_latent=...], **model_kwargs).
        label_dim   = 0,                # Number of class labels, 0 = unconditional.
        use_fp16    = False,            # Execute the underlying model at FP16 precision?
        sigma_min   = 0,                # Minimum supported noise level.
        sigma_max   = float('inf'),     # Maximum supported noise level.
        sigma_data  = 0.5,              # Expected standard deviation of the training data.
    ):
        super().__init__()
        self.label_dim = label_dim
        self.use_fp16 = use_fp16
        # sigma_min / sigma_max are only stored here; nothing in this class
        # clamps sigma to them (presumably read by the sampler — confirm).
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.model = model

    def forward(self, x, sigma, class_labels=None, force_fp32=False, **model_kwargs):
        """Return the EDM-preconditioned denoised estimate ``D_x`` (float32).

        Args:
            x: Noisy input, cast to float32. Sigma is broadcast as
                ``(-1, 1, 1, 1)``, so ``x`` is assumed to be 4-D
                (presumably ``(N, C, H, W)`` — TODO confirm).
            sigma: Noise level(s), a tensor or a Python/NumPy scalar. It is
                converted to float32 on ``x.device``.
            class_labels: Optional conditioning, reshaped to
                ``(-1, label_dim)`` and passed to the model as ``c_latent=``.
                Dropped when ``label_dim == 0``. When omitted, no
                conditioning is passed at all.
            force_fp32: Disable FP16 execution for this call.
            **model_kwargs: Forwarded unchanged to ``model``.

        Returns:
            ``c_skip * x + c_out * F_x`` in float32.
        """
        x = x.to(torch.float32)
        # Put sigma on x's device once. Every coefficient below (c_skip,
        # c_out, c_in, c_noise) inherits it, so none of them can mismatch x
        # or F_x. as_tensor also accepts plain Python scalars.
        sigma = torch.as_tensor(sigma, dtype=torch.float32, device=x.device).reshape(-1, 1, 1, 1)
        if class_labels is not None:
            if self.label_dim == 0:
                class_labels = None
            else:
                class_labels = class_labels.to(torch.float32).reshape(-1, self.label_dim)
        # FP16 is only used on CUDA and only when not explicitly disabled.
        dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32

        c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
        c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
        c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
        # sigma == 0 (the default sigma_min) gives c_noise = -inf.
        c_noise = sigma.log() / 4

        if class_labels is not None:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), c_latent=class_labels, **model_kwargs)
        else:
            F_x = self.model((c_in * x).to(dtype), c_noise.flatten(), **model_kwargs)
        # Sanity check that the model honoured the requested precision.
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x

    def round_sigma(self, sigma):
        """Return ``sigma`` as a tensor. EDM uses continuous sigma, so nothing is rounded."""
        return torch.as_tensor(sigma)
#----------------------------------------------------------------------------