from typing import Optional from torch import Tensor,device from tqdm import tqdm import torch from collections import deque from TorchJaekwon.Util.UtilTorch import UtilTorch from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil class PNDM: #Pseudo Numerical methods for Diffusion Models on manifolds (PNDM) is by Luping Liu, Yi Ren, Zhijie Lin and Zhou Zhao def __init__(self, ddpm_module:DDPM) -> None: self.ddpm_module = ddpm_module @torch.no_grad() def infer(self, x_shape:Optional[tuple], cond:Optional[dict] = None, is_cond_unpack:bool = False, pndm_speedup:int = 10) -> Tensor: _, cond, additional_data_dict = self.ddpm_module.preprocess(x_start = None, cond=cond) if x_shape is None: x_shape = self.ddpm_module.get_x_shape(cond=cond) total_timesteps:int = self.ddpm_module.timesteps model_device:device = UtilTorch.get_model_device(self.ddpm_module) x:Tensor = torch.randn(x_shape, device = model_device) self.noise_list = deque(maxlen=4) for i in tqdm(reversed(range(0, total_timesteps, pndm_speedup)), desc='sample time step', total=total_timesteps // pndm_speedup): x = self.p_sample_plms(x, torch.full((x_shape[0],), i, device=model_device, dtype=torch.long), pndm_speedup, cond, is_cond_unpack) return self.ddpm_module.postprocess(x, additional_data_dict) @torch.no_grad() def p_sample_plms(self, x, t, interval, cond, is_cond_unpack): """ Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). """ noise_list = self.noise_list noise_pred = self.ddpm_module.apply_model(x, t, cond, is_cond_unpack, self.ddpm_module.cfg_scale) if self.ddpm_module.model_output_type == 'v_prediction': noise_pred = self.ddpm_module.predict_noise_from_v(x, t, noise_pred) if len(noise_list) == 0: x_pred = self.get_x_pred(x, noise_pred, t, interval) noise_pred_prev = self.ddpm_module.apply_model(x_pred, torch.max(t-interval, torch.zeros_like(t)), cond, is_cond_unpack) #max(t-interval, 0) noise_pred_prime = (noise_pred + noise_pred_prev) / 2 elif len(noise_list) == 1: noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 elif len(noise_list) == 2: noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 elif len(noise_list) >= 3: noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 x_prev = self.get_x_pred(x, noise_pred_prime, t, interval) noise_list.append(noise_pred) return x_prev def get_x_pred(self, x, noise_t, t, interval): a_t = DiffusionUtil.extract(self.ddpm_module.alphas_cumprod, t, x.shape) a_prev = DiffusionUtil.extract(self.ddpm_module.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape) a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) x_pred = x + x_delta return x_pred