# Source: tellurion, commit d066167 ("initialize huggingface space demo")
"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
import torch
import torch.nn as nn
import numpy as np
from contextlib import contextmanager
from functools import partial
from refnet.util import default, count_params, instantiate_from_config, exists
from refnet.ldm.util import make_beta_schedule, extract_into_tensor
def disabled_train(self, mode=True):
    """No-op stand-in for ``nn.Module.train``.

    Bind this over ``model.train`` so later ``train()``/``eval()`` calls can no
    longer change the module's mode. The requested ``mode`` is ignored and the
    object itself is handed back, as ``nn.Module.train`` does.
    """
    del mode  # deliberately ignored
    return self
def uniform_on_device(r1, r2, shape, device):
    """Draw a tensor of ``shape`` on ``device`` uniformly between ``r2`` and ``r1``.

    Samples ``torch.rand`` values in [0, 1) and maps them affinely to
    ``r2 + u * (r1 - r2)``.
    """
    width = r1 - r2
    return torch.rand(*shape, device=device) * width + r2
def rescale_zero_terminal_snr(betas):
    """
    Rescale a beta schedule so that its terminal SNR is exactly zero.

    Follows Algorithm 1 of https://arxiv.org/pdf/2305.08891.pdf:
    sqrt(alpha_bar) is shifted so its last entry becomes 0, then scaled so its
    first entry keeps its original value, and finally converted back to betas.

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with
            (assumed 1-D; cumprod/cat are taken along dim 0).
    Returns:
        `torch.FloatTensor`: rescaled betas whose last entry is 1, i.e. alpha_bar_T == 0.
    """
    sqrt_ab = torch.cumprod(1.0 - betas, dim=0).sqrt()

    # Endpoints of the original sqrt(alpha_bar) curve.
    first = sqrt_ab[0].clone()
    last = sqrt_ab[-1].clone()

    # Shift so the final value is 0, then stretch so the first value is restored.
    stretch = first / (first - last)
    rescaled = (sqrt_ab - last) * stretch

    # Undo the sqrt and the cumulative product to recover per-step alphas.
    alpha_bar = rescaled**2
    per_step = torch.cat([alpha_bar[:1], alpha_bar[1:] / alpha_bar[:-1]])
    return 1 - per_step
class DDPM(nn.Module):
    """Classic DDPM with Gaussian diffusion and a fixed variance schedule.

    Wraps a denoiser built by ``DiffusionWrapper(unet_config)`` and registers the
    noise-schedule buffers consumed by ``add_noise``, ``get_v`` and
    ``predict_start_from_z_and_v``.
    """

    # classic DDPM with Gaussian diffusion, in image space
    def __init__(
        self,
        unet_config,
        timesteps=1000,
        beta_schedule="scaled_linear",
        image_size=256,
        channels=3,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        v_posterior=0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
        parameterization="eps",  # all assuming fixed variance schedules
        zero_snr=False,
        half_precision_dtype="float16",
        version="sdv1",
        *args,
        **kwargs
    ):
        """
        Args:
            unet_config: config passed to ``DiffusionWrapper`` to build the denoiser.
            timesteps, beta_schedule, linear_start, linear_end, cosine_s, zero_snr:
                forwarded to ``register_schedule``.
            image_size, channels: stored as attributes; not used by this class.
            v_posterior: weight v in sigma^2 = (1 - v) * beta_tilde + v * beta.
            parameterization: "eps" or "v"; anything else fails an assertion.
            zero_snr: only allowed together with ``parameterization == "v"``.
            half_precision_dtype: "float16" or "bfloat16"; stored as the matching
                ``torch`` dtype in ``self.half_precision_dtype``.
            version: ``self.is_sdxl`` is True iff this equals "sdxl".
            *args, **kwargs: accepted and ignored here.
        """
        super().__init__()
        assert parameterization in ["eps", "v"], "currently only supporting 'eps' and 'v'"
        assert half_precision_dtype in ["float16", "bfloat16"], "K-diffusion samplers do not support bfloat16, use float16 by default"
        if zero_snr:
            assert parameterization == "v", 'Zero SNR is only available for "v-prediction" model.'
        self.is_sdxl = (version == "sdxl")
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        # Populated elsewhere (e.g. LatentDiffusion / subclasses); None here.
        self.cond_stage_model = None
        self.img_embedder = None
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.model = DiffusionWrapper(unet_config)
        count_params(self.model, verbose=True)
        # Must be set before register_schedule, which reads it.
        self.v_posterior = v_posterior
        self.half_precision_dtype = torch.bfloat16 if half_precision_dtype == "bfloat16" else torch.float16
        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end, cosine_s=cosine_s, zero_snr=zero_snr)

    def register_schedule(self, beta_schedule="scaled_linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
        """Build the beta schedule and register every derived quantity as a float32 buffer.

        ``make_beta_schedule`` is presumably returning a 1-D numpy array (numpy ops are
        applied to it below) — confirm against ``refnet.ldm.util``. ``self.num_timesteps``
        is taken from the schedule's length, not from the ``timesteps`` argument.
        """
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
                                   cosine_s=cosine_s, zero_snr=zero_snr)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = partial(torch.tensor, dtype=torch.float32)
        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
        # calculations for diffusion q(x_t | x_{t-1}) and others
        # NOTE(review): if the schedule drives alphas_cumprod[-1] to 0 (presumably what
        # zero_snr does — confirm), the 1/alphas_cumprod-based buffers below contain inf.
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap the EMA weights into ``self.model`` when EMA is enabled.

        ``use_ema`` and ``model_ema`` are not set anywhere in this class (presumably by a
        subclass or training code). A missing ``use_ema`` is treated as "EMA disabled"
        instead of raising ``AttributeError``. The flag is read once so the training
        weights are restored exactly when they were stored.

        Args:
            context: optional label; when given, the swap and restore are printed.

        Yields:
            None
        """
        use_ema = getattr(self, "use_ema", False)
        if use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x_0 from a v-prediction: sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v."""
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def add_noise(self, x_start, t, noise=None):
        """Sample q(x_t | x_0): sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        ``noise`` defaults to ``torch.randn_like(x_start)``; the result is cast back to
        ``x_start.dtype`` (the schedule buffers are float32).
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise).to(x_start.dtype)

    def get_v(self, x, noise, t):
        """Velocity target: sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x."""
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def normalize_timesteps(self, timesteps):
        """Identity hook; presumably overridden by subclasses that rescale timesteps."""
        return timesteps
class LatentDiffusion(DDPM):
    """main class

    DDPM operating in the latent space of a frozen first-stage autoencoder, with a
    frozen conditioning encoder. Both sub-models are built from configs, put in eval
    mode and have gradients disabled.
    """

    def __init__(
            self,
            first_stage_config,
            cond_stage_config,
            scale_factor=1.0,
            *args,
            **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.scale_factor = scale_factor

        def build_frozen(config):
            # Instantiate from config, switch to eval mode, disable grads on all params.
            return instantiate_from_config(config).eval().requires_grad_(False)

        first_stage = build_frozen(first_stage_config)
        cond_stage = build_frozen(cond_stage_config)
        self.first_stage_model = first_stage
        self.cond_stage_model = cond_stage

    @torch.no_grad()
    def get_first_stage_encoding(self, x):
        """Encode ``x`` with the first stage, sample the posterior and scale by ``scale_factor``.

        NOTE(review): ``self.dtype`` is not defined by this class, DDPM or nn.Module —
        presumably supplied by a subclass; confirm before calling on a bare instance.
        """
        posterior = self.first_stage_model.encode(x)
        latent = self.scale_factor * posterior.sample()
        return latent.to(self.dtype).detach()

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Undo the latent scaling and decode with the first stage in its own dtype."""
        inv_scale = 1. / self.scale_factor
        unscaled = inv_scale * z
        target_dtype = self.first_stage_model.dtype
        return self.first_stage_model.decode(unscaled.to(target_dtype)).detach()

    def apply_model(self, x_noisy, t, cond):
        """Run the wrapped denoiser with the conditioning dict expanded as keywords."""
        return self.model(x_noisy, t, **cond)

    def get_learned_embedding(self, c, *args, **kwargs):
        """Return detached (wd_emb, wd_logits, clip_emb) for reference input ``c``.

        ``self.img_embedder`` is None after DDPM.__init__ and must be assigned elsewhere
        before this is called; either of its two outputs may be None.
        """
        img_outputs = self.img_embedder.encode(c, **kwargs)
        wd_emb, wd_logits = (out.detach() if exists(out) else None for out in img_outputs)
        clip_emb = self.cond_stage_model.encode(c, **kwargs).detach()
        return wd_emb, wd_logits, clip_emb
class DiffusionWrapper(nn.Module):
    """Wrap the denoiser and merge list-valued conditioning before calling it.

    ``diff_model_config`` is passed to ``instantiate_from_config``; the resulting
    module is stored as ``self.diffusion_model``.
    """

    # Conditioning keys whose values may arrive as a list/tuple of tensors that
    # must be joined along dim 1 before reaching the denoiser.
    _CONCAT_KEYS = ("context", "y", "concat")

    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, **cond):
        """Call the wrapped denoiser on ``x`` at timestep(s) ``t``.

        For the keys "context", "y" and "concat", a list/tuple value is merged with
        ``torch.cat(value, 1)``; a tensor or ``None`` under those keys is forwarded
        unchanged. All other keywords are forwarded as-is. A new dict is built rather
        than mutating ``cond`` while iterating it.
        """
        merged = {
            key: (torch.cat(value, 1)
                  if key in self._CONCAT_KEYS and isinstance(value, (list, tuple))
                  else value)
            for key, value in cond.items()
        }
        return self.diffusion_model(x, t, **merged)