| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
|
|
| import numpy as np |
| import torch |
|
|
| import tqdm |
| from diffusers import ClassifierFreeGuidanceScheduler, GlideDDIMScheduler, CLIPTextModel, DiffusionPipeline, GLIDETextToImageUNetModel, GLIDESuperResUNetModel |
| from transformers import GPT2Tokenizer |
|
|
|
|
def _extract_into_tensor(arr, timesteps, broadcast_shape):
    """
    Gather per-timestep values from a 1-D numpy array and broadcast them.

    :param arr: the 1-D numpy array, indexed by timestep.
    :param timesteps: a tensor of indices into ``arr``; its device is also
        the device of the returned tensor.
    :param broadcast_shape: a larger shape of K dimensions whose batch
        dimension equals the length of ``timesteps``.
    :return: a float tensor of shape ``broadcast_shape`` in which every
        element of batch entry ``i`` equals ``arr[timesteps[i]]``.
    """
    device = timesteps.device
    values = torch.from_numpy(arr).to(device=device)[timesteps].float()
    # Append trailing singleton axes until the rank matches broadcast_shape,
    # so the addition below broadcasts along every non-batch dimension.
    missing_dims = len(broadcast_shape) - len(values.shape)
    if missing_dims > 0:
        values = values.reshape(*values.shape, *([1] * missing_dims))
    return values + torch.zeros(broadcast_shape, device=device)
|
|
|
|
class GLIDE(DiffusionPipeline):
    """
    Two-stage GLIDE text-to-image pipeline.

    Stage 1 samples a 64x64 image with ``text_unet``, conditioned on the prompt
    and using classifier-free guidance against an empty prompt. Stage 2 runs
    ``upscale_unet`` conditioned on that image to sample a 256x256 result.
    """

    def __init__(
        self,
        text_unet: GLIDETextToImageUNetModel,
        text_noise_scheduler: ClassifierFreeGuidanceScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: GPT2Tokenizer,
        upscale_unet: GLIDESuperResUNetModel,
        upscale_noise_scheduler: GlideDDIMScheduler,
    ):
        super().__init__()
        self.register_modules(
            text_unet=text_unet,
            text_noise_scheduler=text_noise_scheduler,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            upscale_unet=upscale_unet,
            upscale_noise_scheduler=upscale_noise_scheduler,
        )

    def q_posterior_mean_variance(self, scheduler, x_start, x_t, t):
        """
        Compute the mean and variance of the diffusion posterior:

            q(x_{t-1} | x_t, x_0)

        :param scheduler: scheduler holding numpy arrays ``posterior_mean_coef1``,
            ``posterior_mean_coef2``, ``posterior_variance`` and
            ``posterior_log_variance_clipped``.
        :param x_start: the (predicted) x_0 tensor; same shape as ``x_t``.
        :param x_t: the noisy tensor at time ``t``.
        :param t: a 1-D tensor of timesteps, one per batch element.
        :return: tuple ``(posterior_mean, posterior_variance,
            posterior_log_variance_clipped)``, each broadcast to ``x_t.shape``.
        """
        assert x_start.shape == x_t.shape
        posterior_mean = (
            _extract_into_tensor(scheduler.posterior_mean_coef1, t, x_t.shape) * x_start
            + _extract_into_tensor(scheduler.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = _extract_into_tensor(scheduler.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = _extract_into_tensor(
            scheduler.posterior_log_variance_clipped, t, x_t.shape
        )
        assert (
            posterior_mean.shape[0]
            == posterior_variance.shape[0]
            == posterior_log_variance_clipped.shape[0]
            == x_start.shape[0]
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, model, scheduler, x, t, transformer_out=None, low_res=None, clip_denoised=True):
        """
        Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
        the initial x, x_0.

        :param model: callable invoked as ``model(x, t, transformer_out)`` when
            ``transformer_out`` is given, otherwise as ``model(x, t, low_res)``.
            Its output must have 2*C channels: the eps prediction followed by
            the variance-interpolation values.
        :param scheduler: scheduler supplying the numpy coefficient arrays
            (``betas``, ``posterior_*``, ``sqrt_recip*_alphas_cumprod``).
        :param x: the [N x C x ...] tensor at time t.
        :param t: a 1-D tensor of N timesteps.
        :param transformer_out: optional text-encoder output passed to the model.
        :param low_res: optional low-resolution conditioning image passed to the
            model; only used when ``transformer_out`` is None.
        :param clip_denoised: if True, clip the predicted x_0 into [-1, 1].
        :return: tuple ``(mean, variance, log_variance, pred_xstart)``:
            the model mean, the model variance, its log, and the x_0 prediction.
        """
        B, C = x.shape[:2]
        assert t.shape == (B,)
        if transformer_out is None:
            model_output = model(x, t, low_res)
        else:
            model_output = model(x, t, transformer_out)

        assert model_output.shape == (B, C * 2, *x.shape[2:])
        model_output, model_var_values = torch.split(model_output, C, dim=1)

        # Interpolate in log space between the clipped posterior log-variance
        # (lower bound) and log(beta_t) (upper bound), weighted by (v + 1) / 2.
        min_log = _extract_into_tensor(scheduler.posterior_log_variance_clipped, t, x.shape)
        max_log = _extract_into_tensor(np.log(scheduler.betas), t, x.shape)
        frac = (model_var_values + 1) / 2
        model_log_variance = frac * max_log + (1 - frac) * min_log
        model_variance = torch.exp(model_log_variance)

        pred_xstart = self._predict_xstart_from_eps(scheduler, x_t=x, t=t, eps=model_output)
        if clip_denoised:
            pred_xstart = pred_xstart.clamp(-1, 1)
        model_mean, _, _ = self.q_posterior_mean_variance(scheduler, x_start=pred_xstart, x_t=x, t=t)

        assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
        return model_mean, model_variance, model_log_variance, pred_xstart

    def _predict_xstart_from_eps(self, scheduler, x_t, t, eps):
        """Recover x_0 from x_t and the predicted noise ``eps``."""
        assert x_t.shape == eps.shape
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
            - _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
        )

    def _predict_eps_from_xstart(self, scheduler, x_t, t, pred_xstart):
        """Recover the noise ``eps`` from x_t and a predicted x_0 (inverse of the above)."""
        return (
            _extract_into_tensor(scheduler.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
        ) / _extract_into_tensor(scheduler.sqrt_recipm1_alphas_cumprod, t, x_t.shape)

    @torch.no_grad()
    def __call__(self, prompt, generator=None, torch_device=None, guidance_scale=3.0, upsample_temp=0.997):
        """
        Generate an image for ``prompt``.

        :param prompt: text prompt to condition on.
        :param generator: optional random generator forwarded to every
            scheduler noise-sampling call.
        :param torch_device: device to run on. When None (the default), uses
            "cuda" if available, otherwise "cpu".
        :param guidance_scale: classifier-free guidance weight for the text stage.
        :param upsample_temp: factor applied to the initial noise of the
            upscaling stage.
        :return: channels-last tensor (N, H, W, C) from the final upscaling
            step, without rescaling; with the current settings (1, 256, 256, 3).
        """
        # Respect an explicitly requested device; previously it was overwritten.
        if torch_device is None:
            torch_device = "cuda" if torch.cuda.is_available() else "cpu"

        self.text_unet.to(torch_device)
        self.text_encoder.to(torch_device)
        self.upscale_unet.to(torch_device)

        low_res = self._sample_text_stage(prompt, generator, torch_device, guidance_scale)
        image = self._sample_upscale_stage(low_res, generator, torch_device, upsample_temp)

        # (N, C, H, W) -> (N, H, W, C)
        return image.permute(0, 2, 3, 1)

    def _sample_text_stage(self, prompt, generator, torch_device, guidance_scale):
        """Sample the 64x64 text-conditioned image; return the prompt-conditioned sample snapped to the 8-bit grid."""

        def text_model_fn(x_t, ts, transformer_out, **kwargs):
            # Classifier-free guidance: evaluate the conditional and unconditional
            # branches on the same (first-half) latents, then mix their eps.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.text_unet(combined, ts, transformer_out, **kwargs)
            eps, rest = model_out[:, :3], model_out[:, 3:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            return torch.cat([eps, rest], dim=1)

        # Row 0 is conditioned on the prompt, row 1 on the empty string.
        batch_size = 2
        image = self.text_noise_scheduler.sample_noise(
            (batch_size, self.text_unet.in_channels, 64, 64), device=torch_device, generator=generator
        )

        inputs = self.tokenizer([prompt, ""], padding="max_length", max_length=128, return_tensors="pt")
        input_ids = inputs["input_ids"].to(torch_device)
        attention_mask = inputs["attention_mask"].to(torch_device)
        transformer_out = self.text_encoder(input_ids, attention_mask).last_hidden_state

        num_timesteps = len(self.text_noise_scheduler)
        for i in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            t = torch.tensor([i] * image.shape[0], device=torch_device)
            mean, _, log_variance, _ = self.p_mean_variance(
                text_model_fn, self.text_noise_scheduler, image, t, transformer_out=transformer_out
            )
            noise = self.text_noise_scheduler.sample_noise(image.shape, device=torch_device, generator=generator)
            # No noise is added on the final step (t == 0).
            nonzero_mask = (t != 0).float().view(-1, *([1] * (len(image.shape) - 1)))
            image = mean + nonzero_mask * torch.exp(0.5 * log_variance) * noise

        # Keep the prompt-conditioned sample and snap it to the 8-bit pixel grid
        # (multiples of 1/127.5, offset by -1).
        image = image[:1]
        return ((image + 1) * 127.5).round() / 127.5 - 1

    def _sample_upscale_stage(self, low_res, generator, torch_device, upsample_temp):
        """Sample the 256x256 super-resolution image conditioned on ``low_res``; returns NCHW."""
        scheduler = self.upscale_noise_scheduler
        batch_size = 1
        image = scheduler.sample_noise(
            (batch_size, 3, 256, 256), device=torch_device, generator=generator
        ) * upsample_temp

        num_timesteps = len(scheduler)
        for t in tqdm.tqdm(reversed(range(num_timesteps)), total=num_timesteps):
            alpha_prod_t = scheduler.get_alpha_prod(t)
            # NOTE(review): at t == 0 this queries index -1; relies on the
            # scheduler handling that case — TODO confirm.
            alpha_prod_t_prev = scheduler.get_alpha_prod(t - 1)

            # Coefficients that recover x_0 from x_t and the predicted noise.
            clipped_image_coeff = 1 / torch.sqrt(alpha_prod_t)
            clipped_noise_coeff = torch.sqrt(1 / alpha_prod_t - 1)
            # DDPM posterior-mean coefficients for x_t and x_0 respectively.
            image_coeff = (1 - alpha_prod_t_prev) * torch.sqrt(scheduler.get_alpha(t)) / (1 - alpha_prod_t)
            clipped_coeff = torch.sqrt(alpha_prod_t_prev) * scheduler.get_beta(t) / (1 - alpha_prod_t)

            time_input = torch.tensor([t] * image.shape[0], device=torch_device)
            model_output = self.upscale_unet(image, time_input, low_res)
            # The second chunk (predicted variance) is not used by this sampler.
            noise_residual, _ = torch.split(model_output, 3, dim=1)

            # Predicted x_0, clamped to [-1, 1], then the posterior mean.
            pred_original = clipped_image_coeff * image - clipped_noise_coeff * noise_residual
            pred_original = torch.clamp(pred_original, -1, 1)
            prev_image = clipped_coeff * pred_original + image_coeff * image

            # Presumably zero-mean noise scaled by the step variance — TODO confirm
            # against the scheduler implementation.
            prev_variance = scheduler.sample_variance(
                t, prev_image.shape, device=torch_device, generator=generator
            )
            image = prev_image + prev_variance

        return image
|
|