"""
wild mixture of
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py
https://github.com/CompVis/taming-transformers
-- merci
"""
from contextlib import contextmanager
from functools import partial

import numpy as np
import torch
import torch.nn as nn

from refnet.util import default, count_params, instantiate_from_config, exists
from refnet.ldm.util import make_beta_schedule, extract_into_tensor


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def uniform_on_device(r1, r2, shape, device):
    """Draw a tensor of the given ``shape`` on ``device`` uniformly between ``r2`` and ``r1``.

    ``torch.rand`` samples from [0, 1); the affine map below sends 0 -> ``r2`` and
    1 -> ``r1``, so the result lies in [r2, r1) when ``r1 > r2``.
    """
    return (r1 - r2) * torch.rand(*shape, device=device) + r2


def rescale_zero_terminal_snr(betas):
    """
    Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)

    Args:
        betas (`torch.FloatTensor`):
            the betas that the scheduler is being initialized with.

    Returns:
        `torch.FloatTensor`: rescaled betas with zero terminal SNR
    """
    # Convert betas to alphas_bar_sqrt
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Store old values.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last timestep is zero.
    alphas_bar_sqrt -= alphas_bar_sqrt_T

    # Scale so the first timestep is back to the old value.
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert alphas_bar_sqrt to betas
    alphas_bar = alphas_bar_sqrt ** 2  # Revert sqrt
    alphas = alphas_bar[1:] / alphas_bar[:-1]  # Revert cumprod
    alphas = torch.cat([alphas_bar[0:1], alphas])
    betas = 1 - alphas

    return betas


class DDPM(nn.Module):
    """Gaussian diffusion with a fixed variance schedule.

    Holds the denoising network (``self.model``, a :class:`DiffusionWrapper`) and
    registers the precomputed noise-schedule quantities as buffers so they follow
    the module across devices / dtypes.
    """
    # classic DDPM with Gaussian diffusion, in image space
    def __init__(
            self,
            unet_config,
            timesteps = 1000,
            beta_schedule = "scaled_linear",
            image_size = 256,
            channels = 3,
            linear_start = 1e-4,
            linear_end = 2e-2,
            cosine_s = 8e-3,
            v_posterior = 0.,  # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta
            parameterization = "eps",  # all assuming fixed variance schedules
            zero_snr = False,
            half_precision_dtype = "float16",
            version = "sdv1",
            *args,
            **kwargs
    ):
        """Build the diffusion model.

        Args:
            unet_config: config passed to ``instantiate_from_config`` for the denoiser.
            timesteps: number of diffusion steps in the schedule.
            beta_schedule, linear_start, linear_end, cosine_s: forwarded to
                ``make_beta_schedule``.
            image_size, channels: stored as attributes only (not used in this module).
            v_posterior: blend weight between the true posterior variance and beta.
            parameterization: ``"eps"`` or ``"v"`` prediction target.
            zero_snr: rescale the schedule to zero terminal SNR; requires ``"v"``.
            half_precision_dtype: ``"float16"`` or ``"bfloat16"``; stored as a torch dtype.
            version: ``"sdxl"`` sets ``self.is_sdxl``; any other value leaves it False.
        """
        super().__init__()
        assert parameterization in ["eps", "v"], "currently only supporting 'eps' and 'v'"
        assert half_precision_dtype in ["float16", "bfloat16"], "K-diffusion samplers do not support bfloat16, use float16 by default"
        if zero_snr:
            assert parameterization == "v", 'Zero SNR is only available for "v-prediction" model.'

        self.is_sdxl = (version == "sdxl")
        self.parameterization = parameterization
        print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode")
        self.cond_stage_model = None
        self.img_embedder = None
        self.image_size = image_size  # try conv?
        self.channels = channels
        self.model = DiffusionWrapper(unet_config)
        count_params(self.model, verbose=True)

        # Must be set before register_schedule, which reads it for the posterior variance.
        self.v_posterior = v_posterior
        self.half_precision_dtype = torch.bfloat16 if half_precision_dtype == "bfloat16" else torch.float16
        self.register_schedule(beta_schedule=beta_schedule, timesteps=timesteps,
                               linear_start=linear_start, linear_end=linear_end,
                               cosine_s=cosine_s, zero_snr=zero_snr)

    def register_schedule(self, beta_schedule="scaled_linear", timesteps=1000,
                          linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3, zero_snr=False):
        """Compute the beta schedule and register all derived quantities as float32 buffers.

        The arrays are built with NumPy and converted with ``torch.tensor``; the
        number of timesteps actually used is taken from the returned ``betas``.
        """
        betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start,
                                   linear_end=linear_end, cosine_s=cosine_s, zero_snr=zero_snr)
        alphas = 1. - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])

        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer('betas', to_torch(betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (1 - self.v_posterior) * betas * (1. - alphas_cumprod_prev) / (
                1. - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer('posterior_variance', to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
        self.register_buffer('posterior_mean_coef1', to_torch(
            betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
        self.register_buffer('posterior_mean_coef2', to_torch(
            (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily run with EMA weights, if EMA is enabled.

        On entry the current parameters are stored (``self.model_ema.store``) and the
        EMA weights copied in (``self.model_ema.copy_to``); on exit, including on error,
        they are put back (``self.model_ema.restore``).

        Args:
            context: optional label; when given, the swap and the restore are printed.
        """
        # Nothing in this module assigns ``use_ema``/``model_ema``: treat a missing flag
        # as "EMA disabled" so the scope is a no-op instead of raising AttributeError.
        # The flag is read once so weights are restored exactly when they were swapped.
        use_ema = getattr(self, "use_ema", False)
        if use_ema:
            self.model_ema.store(self.model.parameters())
            self.model_ema.copy_to(self.model)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if use_ema:
                self.model_ema.restore(self.model.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x_0 from a noisy sample and a v-prediction.

        x_0 = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v, with the per-timestep
        coefficients looked up by ``extract_into_tensor`` for ``t`` and ``x_t.shape``.
        """
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def add_noise(self, x_start, t, noise=None):
        """Sample from q(x_t | x_0): sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise.

        ``noise`` defaults to standard Gaussian noise shaped like ``x_start``; the result
        is cast back to ``x_start.dtype`` (the schedule buffers are float32).
        """
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise).to(x_start.dtype)

    def get_v(self, x, noise, t):
        """Return the v-prediction target sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x."""
        return (
                extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def normalize_timesteps(self, timesteps):
        """Return ``timesteps`` unchanged (presumably an override hook for subclasses)."""
        return timesteps


class LatentDiffusion(DDPM):
    """main class"""
    def __init__(
            self,
            first_stage_config,
            cond_stage_config,
            scale_factor = 1.0,
            *args,
            **kwargs
    ):
        """Build the latent diffusion model.

        Args:
            first_stage_config: config for the first-stage (latent encoder/decoder) model.
            cond_stage_config: config for the conditioning model.
            scale_factor: multiplier applied to encoded latents (and divided out on decode).
            *args, **kwargs: forwarded to :class:`DDPM`.

        Both sub-models are instantiated in eval mode with gradients disabled.
        """
        super().__init__(*args, **kwargs)
        self.scale_factor = scale_factor
        self.first_stage_model, self.cond_stage_model = map(
            lambda t: instantiate_from_config(t).eval().requires_grad_(False),
            (first_stage_config, cond_stage_config)
        )

    def _latent_dtype(self):
        """Dtype latents are cast to: ``self.dtype`` when defined, else the denoiser's parameter dtype."""
        # nn.Module itself has no ``dtype`` attribute; use it only if something provides it.
        dtype = getattr(self, "dtype", None)
        if dtype is not None:
            return dtype
        param = next(self.model.parameters(), None)
        return param.dtype if param is not None else torch.float32

    @torch.no_grad()
    def get_first_stage_encoding(self, x):
        """Encode ``x`` with the first-stage model, sample the posterior and scale it."""
        encoder_posterior = self.first_stage_model.encode(x)
        z = encoder_posterior.sample() * self.scale_factor
        return z.to(self._latent_dtype()).detach()

    @torch.no_grad()
    def decode_first_stage(self, z):
        """Undo ``scale_factor`` and decode latents with the first-stage model (in its dtype)."""
        z = 1. / self.scale_factor * z
        return self.first_stage_model.decode(z.to(self.first_stage_model.dtype)).detach()

    def apply_model(self, x_noisy, t, cond):
        """Run the denoiser; ``cond`` is a dict expanded into keyword arguments."""
        return self.model(x_noisy, t, **cond)

    def get_learned_embedding(self, c, *args, **kwargs):
        """Encode ``c`` with both embedders and return ``(wd_emb, wd_logits, clip_emb)``, detached.

        ``self.img_embedder.encode`` must return a 2-tuple; ``None`` entries are passed
        through. NOTE(review): ``img_embedder`` is initialised to ``None`` in ``DDPM`` and
        not assigned in this module — presumably set by a subclass; confirm.
        """
        wd_emb, wd_logits = map(lambda t: t.detach() if exists(t) else None, self.img_embedder.encode(c, **kwargs))
        clip_emb = self.cond_stage_model.encode(c, **kwargs).detach()
        return wd_emb, wd_logits, clip_emb


class DiffusionWrapper(nn.Module):
    """Thin wrapper that instantiates the denoising network and merges conditioning inputs."""
    def __init__(self, diff_model_config):
        super().__init__()
        self.diffusion_model = instantiate_from_config(diff_model_config)

    def forward(self, x, t, **cond):
        """Call the wrapped network, concatenating list-valued conditions along dim 1.

        For the keys ``"context"``, ``"y"`` and ``"concat"``, the value must be a
        sequence of tensors accepted by ``torch.cat(..., 1)``; other keys pass through.
        ``cond`` is this call's own kwargs dict, so the caller's dict is not modified.
        """
        for k in cond:
            if k in ["context", "y", "concat"]:
                cond[k] = torch.cat(cond[k], 1)
        out = self.diffusion_model(x, t, **cond)
        return out