| from typing import Union, Callable, Literal, Optional, Tuple |
| from numpy import ndarray |
| from torch import Tensor, device |
|
|
| from tqdm import tqdm |
| import numpy as np |
| import torch |
| from torch import nn |
| import torch.nn.functional as F |
|
|
| from TorchJaekwon.GetModule import GetModule |
| from TorchJaekwon.Util.UtilData import UtilData |
| from TorchJaekwon.Util.UtilTorch import UtilTorch |
| from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil |
| from TorchJaekwon.Model.Diffusion.DDPM.BetaSchedule import BetaSchedule |
|
|
class DDPM(nn.Module):
    """Denoising diffusion wrapper around a denoiser network.

    The wrapped ``self.model`` is called as ``model(x_t, t[, cond])`` and its
    output is interpreted according to ``model_output_type``:

    * ``'noise'``        - the noise that was mixed into ``x_start``
    * ``'x_start'``      - the clean sample itself
    * ``'v_prediction'`` - the ``v`` target built by :meth:`get_v`

    ``forward`` returns the training loss in the ``'train'`` stage (when
    ``x_start`` is given) and otherwise runs the full reverse sampling loop.
    Subclasses are expected to override :meth:`preprocess`,
    :meth:`postprocess`, :meth:`get_unconditional_condition` and
    :meth:`get_x_shape` as needed.
    """

    def __init__(self,
                 model_class_name: Optional[str] = None,
                 model: Optional[nn.Module] = None,
                 model_output_type: Literal['noise', 'x_start', 'v_prediction'] = 'noise',
                 timesteps: int = 1000,
                 loss_func: Union[nn.Module, Callable, Tuple[str, str]] = F.mse_loss,
                 betas: Optional[ndarray] = None,
                 beta_schedule_type: Literal['linear', 'cosine'] = 'cosine',
                 beta_arg_dict: Optional[dict] = None,
                 unconditional_prob: float = 0,
                 cfg_scale: Optional[float] = None
                 ) -> None:
        """Build the wrapper and precompute the noise schedule buffers.

        Args:
            model_class_name: if given, the denoiser is obtained through
                ``GetModule.get_model`` and ``model`` is ignored.
            model: denoiser instance used when ``model_class_name`` is None.
            model_output_type: how the denoiser output is interpreted.
            timesteps: number of diffusion steps requested from the schedule.
            loss_func: callable invoked as ``loss_func(target, model_output)``.
            betas: explicit beta schedule; when None one is generated.
            beta_schedule_type: name of the ``BetaSchedule`` attribute to call.
            beta_arg_dict: extra keyword arguments for the schedule function
                (never mutated; None means no extra arguments).
            unconditional_prob: probability of replacing the condition by the
                unconditional one for a training step.
            cfg_scale: classifier-free guidance scale used while sampling;
                None or 1.0 disables guidance.
        """
        super().__init__()
        if model_class_name is not None:
            self.model = GetModule.get_model(model_name=model_class_name)
        else:
            self.model: nn.Module = model
        self.model_output_type: Literal['noise', 'x_start', 'v_prediction'] = model_output_type

        self.loss_func: Union[nn.Module, Callable] = loss_func

        self.timesteps: int = timesteps
        self.set_noise_schedule(betas=betas,
                                beta_schedule_type=beta_schedule_type,
                                beta_arg_dict=beta_arg_dict,
                                timesteps=timesteps)

        self.unconditional_prob: float = unconditional_prob
        self.cfg_scale: Optional[float] = cfg_scale

    def set_noise_schedule(self,
                           betas: Optional[ndarray] = None,
                           beta_schedule_type: Literal['linear', 'cosine'] = 'linear',
                           beta_arg_dict: Optional[dict] = None,
                           timesteps: int = 1000,
                           ) -> None:
        """Register every schedule-derived coefficient as a buffer.

        Args:
            betas: explicit beta schedule. When None, it is generated by
                ``getattr(BetaSchedule, beta_schedule_type)`` called with
                ``beta_arg_dict`` plus ``timesteps``.
            beta_schedule_type: name of the ``BetaSchedule`` function.
            beta_arg_dict: extra keyword arguments for that function. A copy
                is used, so neither the caller's dict nor a shared default is
                modified.
            timesteps: number of steps requested from the generated schedule.
        """
        if betas is None:
            # Copy before inserting 'timesteps' so the caller's dict stays untouched.
            schedule_kwargs: dict = dict(beta_arg_dict) if beta_arg_dict is not None else {}
            schedule_kwargs['timesteps'] = timesteps
            betas = getattr(BetaSchedule, beta_schedule_type)(**schedule_kwargs)

        # Every buffer below is indexed by t, so the step count used for
        # sampling t and for the reverse loop must equal the schedule length.
        self.timesteps = len(betas)

        alphas: ndarray = 1. - betas
        alphas_cumprod: ndarray = np.cumprod(alphas, axis=0)
        # Cumulative product shifted by one step, with 1.0 for the step before t=0.
        alphas_cumprod_prev: ndarray = np.append(1., alphas_cumprod[:-1])

        self.betas: Tensor = UtilTorch.register_buffer(model=self, variable_name='betas', value=betas)
        self.alphas_cumprod: Tensor = UtilTorch.register_buffer(model=self, variable_name='alphas_cumprod', value=alphas_cumprod)
        self.alphas_cumprod_prev: Tensor = UtilTorch.register_buffer(model=self, variable_name='alphas_cumprod_prev', value=alphas_cumprod_prev)

        # Coefficients of the forward process q(x_t | x_0) and of its inversion.
        self.sqrt_alphas_cumprod: Tensor = UtilTorch.register_buffer(model=self, variable_name='sqrt_alphas_cumprod', value=np.sqrt(alphas_cumprod))
        self.sqrt_one_minus_alphas_cumprod: Tensor = UtilTorch.register_buffer(model=self, variable_name='sqrt_one_minus_alphas_cumprod', value=np.sqrt(1. - alphas_cumprod))
        self.log_one_minus_alphas_cumprod: Tensor = UtilTorch.register_buffer(model=self, variable_name='log_one_minus_alphas_cumprod', value=np.log(1. - alphas_cumprod))
        self.sqrt_recip_alphas_cumprod: Tensor = UtilTorch.register_buffer(model=self, variable_name='sqrt_recip_alphas_cumprod', value=np.sqrt(1. / alphas_cumprod))
        self.sqrt_recipm1_alphas_cumprod: Tensor = UtilTorch.register_buffer(model=self, variable_name='sqrt_recipm1_alphas_cumprod', value=np.sqrt(1. / alphas_cumprod - 1))

        # Coefficients of the posterior q(x_{t-1} | x_t, x_0).
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        self.posterior_variance: Tensor = UtilTorch.register_buffer(model=self, variable_name='posterior_variance', value=posterior_variance)
        # The variance is 0 at t=0 (alphas_cumprod_prev[0] == 1), so clip before the log.
        self.posterior_log_variance_clipped: Tensor = UtilTorch.register_buffer(model=self, variable_name='posterior_log_variance_clipped', value=np.log(np.maximum(posterior_variance, 1e-20)))
        self.posterior_mean_coef1: Tensor = UtilTorch.register_buffer(model=self, variable_name='posterior_mean_coef1', value=betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        self.posterior_mean_coef2: Tensor = UtilTorch.register_buffer(model=self, variable_name='posterior_mean_coef2', value=(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))

    def forward(self,
                x_start: Optional[Tensor] = None,
                x_shape: Optional[tuple] = None,
                cond: Optional[Union[dict, Tensor]] = None,
                is_cond_unpack: bool = False,
                stage: Literal['train', 'infer'] = 'train'
                ) -> Tensor:
        """Return the diffusion loss (training) or a generated sample (inference).

        The training branch is taken only when ``stage == 'train'`` and
        ``x_start`` (after :meth:`preprocess`) is not None; every other
        combination runs :meth:`infer`.

        Args:
            x_start: clean data used as the training target.
            x_shape: shape of the sample; defaults to ``x_start.shape`` when
                training. Its first entry is used as the batch size.
            cond: conditioning passed to the denoiser.
            is_cond_unpack: when True ``cond`` is passed as ``**cond``.
            stage: ``'train'`` or ``'infer'``.
        """
        x_start, cond, additional_data_dict = self.preprocess(x_start, cond)
        if stage == 'train' and x_start is not None:
            if x_shape is None:
                x_shape = x_start.shape
            batch_size: int = x_shape[0]
            input_device: device = x_start.device
            # One uniformly drawn timestep per batch element.
            t: Tensor = torch.randint(0, self.timesteps, (batch_size,), device=input_device).long()
            # Condition dropout: the whole batch is made unconditional at once.
            if DDPM.make_decision(self.unconditional_prob):
                cond = self.get_unconditional_condition(cond=cond, condition_device=input_device)
            return self.p_losses(x_start, cond, is_cond_unpack, t)
        return self.infer(x_shape=x_shape,
                          cond=cond,
                          is_cond_unpack=is_cond_unpack,
                          additional_data_dict=additional_data_dict)

    def p_losses(self,
                 x_start: Tensor,
                 cond: Optional[Union[dict, Tensor]],
                 is_cond_unpack: bool,
                 t: Tensor,
                 noise: Optional[Tensor] = None):
        """Noise ``x_start`` to step ``t``, run the denoiser and return the loss.

        Raises:
            NotImplementedError: if ``model_output_type`` is not one of
                ``'x_start'``, ``'noise'`` or ``'v_prediction'``.
        """
        # NOTE(review): UtilData.default presumably returns `noise` unless it is
        # None, in which case the lambda is called - confirm against UtilData.
        noise = UtilData.default(noise, lambda: torch.randn_like(x_start))
        x_noisy: Tensor = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_output: Tensor = self.apply_model(x_noisy, t, cond, is_cond_unpack)

        if self.model_output_type == 'x_start':
            target: Tensor = x_start
        elif self.model_output_type == 'noise':
            target = noise
        elif self.model_output_type == 'v_prediction':
            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError(
                f'model output type is {self.model_output_type}. '
                'It should be in [x_start, noise, v_prediction]'
            )
        if target.shape != model_output.shape:
            print(f'warning: target shape({target.shape}) and model shape({model_output.shape}) are different')
        # The loss is called with the target first, then the prediction.
        return self.loss_func(target, model_output)

    def get_v(self, x, noise, t):
        '''
        Build the v target: sqrt(alphas_cumprod) * noise - sqrt(1 - alphas_cumprod) * x.

        Progressive Distillation for Fast Sampling of Diffusion Models
        https://arxiv.org/abs/2202.00512
        '''
        # NOTE(review): DiffusionUtil.extract presumably picks the coefficient at
        # index t for each sample and shapes it to broadcast with x - confirm.
        signal_scale = DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x.shape)
        noise_scale = DiffusionUtil.extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        return signal_scale * noise - noise_scale * x

    def q_sample(self, x_start: Tensor, t: Tensor, noise=None) -> Tensor:
        '''
        Noisy sample of the forward process:
        sqrt(alphas_cumprod) * x_start + sqrt(1 - alphas_cumprod) * noise.
        Fresh Gaussian noise is drawn when ``noise`` is not supplied.
        '''
        noise = UtilData.default(noise, lambda: torch.randn_like(x_start))
        signal_scale = DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        noise_scale = DiffusionUtil.extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        return signal_scale * x_start + noise_scale * noise

    def q_mean_variance(self, x_start, t):
        """
        Get the distribution q(x_t | x_0).
        :param x_start: the [N x C x ...] tensor of noiseless inputs.
        :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
        :return: A tuple (mean, variance, log_variance), all of x_start's shape.
        """
        mean = DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
        variance = DiffusionUtil.extract(1.0 - self.alphas_cumprod, t, x_start.shape)
        log_variance = DiffusionUtil.extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
        return mean, variance, log_variance

    @torch.no_grad()
    def infer(self,
              x_shape: tuple,
              cond: Optional[Union[dict, Tensor]],
              is_cond_unpack: bool,
              additional_data_dict: dict):
        """Run the full reverse process from Gaussian noise and postprocess it.

        Args:
            x_shape: shape of the sample to generate; when None it is asked
                from :meth:`get_x_shape`.
            cond: conditioning passed to the denoiser at every step.
            is_cond_unpack: when True ``cond`` is passed as ``**cond``.
            additional_data_dict: forwarded untouched to :meth:`postprocess`.

        Raises:
            ValueError: if no sample shape can be determined.
        """
        if x_shape is None:
            x_shape = self.get_x_shape(cond)
        if x_shape is None:
            raise ValueError('x_shape is None: pass x_shape or override get_x_shape')
        model_device: device = UtilTorch.get_model_device(self.model)
        x: Tensor = torch.randn(x_shape, device=model_device)
        # Walk the chain backwards: timesteps - 1, ..., 0.
        for step in tqdm(reversed(range(0, self.timesteps)), desc='sample time step', total=self.timesteps):
            t = torch.full((x_shape[0],), step, device=model_device, dtype=torch.long)
            x = self.p_sample(x=x, t=t, cond=cond, is_cond_unpack=is_cond_unpack)

        return self.postprocess(x, additional_data_dict=additional_data_dict)

    @torch.no_grad()
    def p_sample(self,
                 x: Tensor,
                 t: Tensor,
                 cond: Optional[Union[dict, Tensor]],
                 is_cond_unpack: bool,
                 clip_denoised: bool = False,
                 repeat_noise: bool = False):
        """Take one reverse step: posterior mean plus scaled noise (none at t == 0)."""
        batch_size = x.shape[0]
        model_mean, _, model_log_variance = self.p_mean_variance(x=x,
                                                                 t=t,
                                                                 cond=cond,
                                                                 is_cond_unpack=is_cond_unpack,
                                                                 clip_denoised=clip_denoised)
        # NOTE(review): DiffusionUtil.noise_like presumably draws Gaussian noise of
        # this shape on this device, shared across the batch when repeat_noise - confirm.
        noise = DiffusionUtil.noise_like(x.shape, x.device, repeat_noise)
        # Mask is 0 where t == 0 and 1 elsewhere, shaped (batch, 1, ..., 1) to broadcast.
        nonzero_mask = (1 - (t == 0).float()).reshape(batch_size, *((1,) * (len(x.shape) - 1)))
        # exp(0.5 * log_variance) is the standard deviation.
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise

    def p_mean_variance(self,
                        x: Tensor,
                        t: Tensor,
                        cond: Optional[Union[dict, Tensor]],
                        is_cond_unpack: bool,
                        clip_denoised: bool) -> Tuple[Tensor, Tensor, Tensor]:
        """Return (mean, variance, log variance) of the reverse step at ``t``.

        The denoiser output (with ``self.cfg_scale`` guidance) is converted to
        an ``x_start`` estimate, optionally clamped to [-1, 1], and plugged
        into :meth:`q_posterior`.

        Raises:
            NotImplementedError: if ``model_output_type`` is not supported.
        """
        model_output: Tensor = self.apply_model(x, t, cond, is_cond_unpack, cfg_scale=self.cfg_scale)
        if self.model_output_type == 'noise':
            x_recon = self.predict_x_start_from_noise(x, t=t, noise=model_output)
        elif self.model_output_type == 'x_start':
            x_recon = model_output
        elif self.model_output_type == 'v_prediction':
            x_recon = self.predict_x_start_from_v(x, t=t, v=model_output)
        else:
            # Without this branch x_recon would be unbound further down.
            raise NotImplementedError(
                f'model output type is {self.model_output_type}. '
                'It should be in [x_start, noise, v_prediction]'
            )

        if clip_denoised:
            x_recon.clamp_(-1., 1.)

        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
        return model_mean, posterior_variance, posterior_log_variance

    def predict_x_start_from_noise(self, x_t, t, noise):
        """Recover x_start from x_t and noise by inverting the q_sample formula."""
        recip_scale = DiffusionUtil.extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape)
        recipm1_scale = DiffusionUtil.extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        return recip_scale * x_t - recipm1_scale * noise

    def predict_x_start_from_v(self, x_t, t, v):
        """Recover x_start as sqrt(alphas_cumprod) * x_t - sqrt(1 - alphas_cumprod) * v."""
        signal_scale = DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x_t.shape)
        noise_scale = DiffusionUtil.extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
        return signal_scale * x_t - noise_scale * v

    def predict_noise_from_v(self, x_t, t, v):
        """Recover the noise as sqrt(alphas_cumprod) * v + sqrt(1 - alphas_cumprod) * x_t."""
        signal_scale = DiffusionUtil.extract(self.sqrt_alphas_cumprod, t, x_t.shape)
        noise_scale = DiffusionUtil.extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
        return signal_scale * v + noise_scale * x_t

    def q_posterior(self, x_start, x_t, t):
        """Return mean, variance and clipped log variance of q(x_{t-1} | x_t, x_0)."""
        coef_x_start = DiffusionUtil.extract(self.posterior_mean_coef1, t, x_t.shape)
        coef_x_t = DiffusionUtil.extract(self.posterior_mean_coef2, t, x_t.shape)
        posterior_mean = coef_x_start * x_start + coef_x_t * x_t
        posterior_variance = DiffusionUtil.extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = DiffusionUtil.extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def preprocess(self, x_start: Tensor, cond: Optional[Union[dict, Tensor]] = None) -> Tuple[Tensor, Optional[Union[dict, Tensor]], dict]:
        """Hook run at the start of ``forward``.

        The default returns the inputs unchanged and ``None`` as the
        additional data that is later handed to :meth:`postprocess`.
        """
        return x_start, cond, None

    def postprocess(self, x: Tensor, additional_data_dict: dict) -> Tensor:
        """Hook run on the generated sample; the default returns ``x`` unchanged."""
        return x

    def apply_model(self,
                    x: Tensor,
                    t: Tensor,
                    cond: Optional[Union[dict, Tensor]],
                    is_cond_unpack: bool,
                    cfg_scale: Optional[float] = None
                    ) -> Tensor:
        """Call the denoiser, optionally with classifier-free guidance.

        Without guidance (``cfg_scale`` is None or 1.0), or when there is no
        condition to guide with, the model is called once. Otherwise it is
        called with ``cond`` and with the unconditional condition, and the
        result is ``uncond + cfg_scale * (cond - uncond)``.

        Args:
            x: noisy input at step ``t``.
            t: timestep indices.
            cond: conditioning; None calls ``model(x, t)``.
            is_cond_unpack: when True ``cond`` is passed as ``**cond``.
            cfg_scale: guidance scale.
        """
        # Guidance needs a condition; with cond=None fall back to the plain call
        # instead of failing inside get_unconditional_condition.
        if cfg_scale is None or cfg_scale == 1.0 or cond is None:
            if cond is None:
                return self.model(x, t)
            if is_cond_unpack:
                return self.model(x, t, **cond)
            return self.model(x, t, cond)

        unconditional_conditioning = self.get_unconditional_condition(cond=cond)
        if is_cond_unpack:
            model_conditioned_output = self.model(x, t, **cond)
            model_unconditioned_output = self.model(x, t, **unconditional_conditioning)
        else:
            model_conditioned_output = self.model(x, t, cond)
            model_unconditioned_output = self.model(x, t, unconditional_conditioning)
        return model_unconditioned_output + cfg_scale * (model_conditioned_output - model_unconditioned_output)

    @staticmethod
    def make_decision(probability: float
                      ) -> bool:
        """Return True with the given probability (always False for 0)."""
        if probability == 0:
            return False
        # torch.rand draws from [0, 1), so probability >= 1 always succeeds.
        return float(torch.rand(1)) < probability

    def get_unconditional_condition(self,
                                    cond: Optional[Union[dict, Tensor]] = None,
                                    cond_shape: Optional[tuple] = None,
                                    condition_device: Optional[device] = None
                                    ) -> Tensor:
        """Return the default unconditional condition: a tensor filled with -11.4981.

        The shape comes from ``cond_shape`` or, when that is None, from
        ``cond.shape`` (so the default only supports tensor conditions). A
        tensor ``cond`` also overrides ``condition_device`` with its own
        device. Subclasses are expected to override this method.
        """
        print('Default Unconditional Condition. You might wanna overwrite this function')
        if cond_shape is None:
            cond_shape = cond.shape
        if cond is not None and isinstance(cond, Tensor):
            condition_device = cond.device
        # Allocate directly on the target device instead of building on CPU and moving.
        return torch.full(tuple(cond_shape), -11.4981, device=condition_device)

    def get_x_shape(self, cond: Optional[Union[dict, Tensor]] = None):
        """Hook giving the sample shape for inference when none is passed; default None."""
        return None
|
|
if __name__ == '__main__':
    # Construction smoke test: the string 'debug' is only a placeholder that
    # gets stored as the model attribute; no denoiser is called here.
    DDPM(model='debug')
|
|
| |
| |
| |
|
|