from typing import List
from functools import partial

import numpy as np
import torch
import torch.nn as nn

from .modules.diffusionmodules.util import (
    make_beta_schedule,
    extract_into_tensor,
    enforce_zero_terminal_snr,
    noise_like,
)
from .util import exists, default, instantiate_from_config
from .modules.distributions.distributions import DiagonalGaussianDistribution

class DiffusionWrapper(nn.Module):
    """Thin wrapper that forwards all calls to the wrapped diffusion (UNet) model."""

    def __init__(self, diffusion_model):
        super().__init__()
        self.diffusion_model = diffusion_model

    def forward(self, *args, **kwargs):
        return self.diffusion_model(*args, **kwargs)

class LatentDiffusionInterface(nn.Module):
    """A simple interface class for LDM (latent diffusion model) inference."""

    def __init__(
        self,
        unet_config,
        clip_config,
        vae_config,
        parameterization="eps",
        scale_factor=0.18215,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=0.00085,
        linear_end=0.0120,
        cosine_s=8e-3,
        given_betas=None,
        zero_snr=False,
        *args,
        **kwargs,
    ):
        super().__init__()
        # Instantiate the denoising UNet, the CLIP conditioning model and the VAE
        # from their respective configs.
        unet = instantiate_from_config(unet_config)
        self.model = DiffusionWrapper(unet)
        self.clip_model = instantiate_from_config(clip_config)
        self.vae_model = instantiate_from_config(vae_config)

        self.parameterization = parameterization  # UNet prediction target, e.g. "eps" or "v"
        self.scale_factor = scale_factor
        self.register_schedule(
            given_betas=given_betas,
            beta_schedule=beta_schedule,
            timesteps=timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
            zero_snr=zero_snr,
        )

    def register_schedule(
        self,
        given_betas=None,
        beta_schedule="linear",
        timesteps=1000,
        linear_start=1e-4,
        linear_end=2e-2,
        cosine_s=8e-3,
        zero_snr=False,
    ):
        if exists(given_betas):
            betas = given_betas
        else:
            betas = make_beta_schedule(
                beta_schedule,
                timesteps,
                linear_start=linear_start,
                linear_end=linear_end,
                cosine_s=cosine_s,
            )
        if zero_snr:
            print("--- using zero snr---")
            betas = enforce_zero_terminal_snr(betas).numpy()
        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        (timesteps,) = betas.shape
        self.num_timesteps = int(timesteps)
        self.linear_start = linear_start
        self.linear_end = linear_end
        assert (
            alphas_cumprod.shape[0] == self.num_timesteps
        ), "alphas have to be defined for each timestep"

        to_torch = partial(torch.tensor, dtype=torch.float32)

        self.register_buffer("betas", to_torch(betas))
        self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
        self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
        )
        self.register_buffer(
            "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
        )
        eps = 1e-8  # small epsilon to avoid division by zero when alphas_cumprod approaches 0
        self.register_buffer(
            "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / (alphas_cumprod + eps)))
        )
        self.register_buffer(
            "sqrt_recipm1_alphas_cumprod",
            to_torch(np.sqrt(1.0 / (alphas_cumprod + eps) - 1)),
        )

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        self.v_posterior = 0
        posterior_variance = (1 - self.v_posterior) * betas * (
            1.0 - alphas_cumprod_prev
        ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        self.register_buffer("posterior_variance", to_torch(posterior_variance))
        # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
        self.register_buffer(
            "posterior_log_variance_clipped",
            to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
        )
        self.register_buffer(
            "posterior_mean_coef1",
            to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
        )
        self.register_buffer(
            "posterior_mean_coef2",
            to_torch(
                (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
            ),
        )

    def q_sample(self, x_start, t, noise=None):
        """Sample x_t ~ q(x_t | x_0): sqrt(a_bar_t) * x_start + sqrt(1 - a_bar_t) * noise."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
            * noise
        )

    def get_v(self, x, noise, t):
        """Compute the v-prediction target v = sqrt(a_bar_t) * noise - sqrt(1 - a_bar_t) * x."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * noise
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

    def predict_start_from_noise(self, x_t, t, noise):
        """Recover x_0 from x_t and a predicted noise (eps parameterization)."""
        return (
            extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
            * noise
        )

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x_0 from x_t and a predicted v (v parameterization)."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t
            - extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Recover the noise eps from x_t and a predicted v (v parameterization)."""
        return (
            extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v
            + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
            * x_t
        )
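
    # The conversions above all follow from x_t = sqrt(a_bar_t) * x_0
    # + sqrt(1 - a_bar_t) * eps together with the definition
    # v = sqrt(a_bar_t) * eps - sqrt(1 - a_bar_t) * x_0, which yields
    # x_0 = sqrt(a_bar_t) * x_t - sqrt(1 - a_bar_t) * v and
    # eps = sqrt(a_bar_t) * v + sqrt(1 - a_bar_t) * x_t.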

    def apply_model(self, x_noisy, t, cond, **kwargs):
        assert isinstance(cond, dict), "cond has to be a dictionary"
        return self.model(x_noisy, t, **cond, **kwargs)

    def get_learned_conditioning(self, prompts: List[str]):
        return self.clip_model(prompts)

    def get_learned_image_conditioning(self, images):
        return self.clip_model.forward_image(images)

    def get_first_stage_encoding(self, encoder_posterior):
        if isinstance(encoder_posterior, DiagonalGaussianDistribution):
            z = encoder_posterior.sample()
        elif isinstance(encoder_posterior, torch.Tensor):
            z = encoder_posterior
        else:
            raise NotImplementedError(
                f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
            )
        return self.scale_factor * z

    def encode_first_stage(self, x):
        return self.vae_model.encode(x)

    def decode_first_stage(self, z):
        z = 1.0 / self.scale_factor * z
        return self.vae_model.decode(z)
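
# Illustrative usage sketch, assuming suitably prepared unet_config, clip_config
# and vae_config objects (hypothetical placeholders here; in practice they come
# from the project's model configs) and a conditioning key ("context") that
# matches the UNet's forward() signature:
#
#   model = LatentDiffusionInterface(unet_config, clip_config, vae_config)
#   cond = {"context": model.get_learned_conditioning(["a photo of a cat"])}
#   z = model.get_first_stage_encoding(model.encode_first_stage(images))
#   t = torch.randint(0, model.num_timesteps, (z.shape[0],), device=z.device)
#   z_noisy = model.q_sample(z, t)
#   pred = model.apply_model(z_noisy, t, cond)  # predicted eps or v, per parameterization
#   images_rec = model.decode_first_stage(z)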