FlashSR_One-step_Versatile_Audio_Super-resolution / TorchJaekwon /Model /Diffusion /DDPM /DDPMLossVLB.py
| from typing import Union, Callable, Literal, Optional, Tuple | |
| from numpy import ndarray | |
| from torch import Tensor, device | |
| from tqdm import tqdm | |
| import numpy as np | |
| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from TorchJaekwon.GetModule import GetModule | |
| from TorchJaekwon.Util.UtilData import UtilData | |
| from TorchJaekwon.Util.UtilTorch import UtilTorch | |
| from TorchJaekwon.Model.Diffusion.DDPM.DiffusionUtil import DiffusionUtil | |
| from TorchJaekwon.Model.Diffusion.DDPM.BetaSchedule import BetaSchedule | |
| from TorchJaekwon.Model.Diffusion.DDPM.DDPM import DDPM | |
| class DDPMLossVLB(DDPM): | |
| def __init__(self, | |
| use_vlb_loss:bool = True, | |
| loss_simple_weight:float=1.0, | |
| original_elbo_weight:float=0.0, | |
| logvar_init:float=0.0, | |
| learn_logvar:bool=False, | |
| *args, | |
| **kwargs): | |
| self.use_vlb_loss = use_vlb_loss | |
| super().__init__(*args, **kwargs) | |
| self.loss_simple_weight:float = loss_simple_weight | |
| self.original_elbo_weight:float = original_elbo_weight | |
| self.logvar:float = torch.full(fill_value=logvar_init, size=(self.timesteps,)) | |
| self.learn_logvar:bool = learn_logvar | |
| if self.learn_logvar: | |
| self.logvar = nn.Parameter(self.logvar, requires_grad=True) | |
| else: | |
| self.logvar = nn.Parameter(self.logvar, requires_grad=False) | |
| def set_noise_schedule(self, | |
| betas: Optional[ndarray] = None, | |
| beta_schedule_type:Literal['linear','cosine'] = 'linear', | |
| beta_arg_dict:dict = dict(), | |
| timesteps:int = 1000, | |
| ) -> None: | |
| if betas is None: | |
| beta_arg_dict.update({'timesteps':timesteps}) | |
| betas = getattr(BetaSchedule,beta_schedule_type)(**beta_arg_dict) | |
| alphas:ndarray = 1. - betas | |
| alphas_cumprod:ndarray = np.cumprod(alphas, axis=0) | |
| alphas_cumprod_prev:ndarray = np.append(1., alphas_cumprod[:-1]) | |
| self.betas:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'betas', value = betas) | |
| self.alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'alphas_cumprod', value = alphas_cumprod) | |
| self.alphas_cumprod_prev:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'alphas_cumprod_prev', value = alphas_cumprod_prev) | |
| # calculations for diffusion q(x_t | x_{t-1}) and others | |
| self.sqrt_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_alphas_cumprod', value = np.sqrt(alphas_cumprod)) | |
| self.sqrt_one_minus_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_one_minus_alphas_cumprod', value = np.sqrt(1. - alphas_cumprod)) | |
| self.log_one_minus_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'log_one_minus_alphas_cumprod', value = np.log(1. - alphas_cumprod)) | |
| self.sqrt_recip_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_recip_alphas_cumprod', value = np.sqrt(1. / alphas_cumprod)) | |
| self.sqrt_recipm1_alphas_cumprod:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'sqrt_recipm1_alphas_cumprod', value = np.sqrt(1. / alphas_cumprod - 1)) | |
| # calculations for posterior q(x_{t-1} | x_t, x_0) | |
| posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) | |
| # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) | |
| self.posterior_variance:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_variance', value = posterior_variance) | |
| # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain | |
| self.posterior_log_variance_clipped:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_log_variance_clipped', value = np.log(np.maximum(posterior_variance, 1e-20))) | |
| self.posterior_mean_coef1:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_mean_coef1', value = betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) | |
| self.posterior_mean_coef2:Tensor = UtilTorch.register_buffer(model = self, variable_name = 'posterior_mean_coef2', value = (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)) | |
| if self.use_vlb_loss: | |
| if self.model_output_type == 'noise': | |
| lvlb_weights = self.betas**2 / ( | |
| 2 | |
| * self.posterior_variance | |
| * torch.tensor(alphas, dtype=torch.float32) | |
| * (1 - self.alphas_cumprod) | |
| ) | |
| elif self.model_output_type == 'x_start': | |
| lvlb_weights = ( | |
| 0.5 | |
| * np.sqrt(torch.Tensor(alphas_cumprod)) | |
| / (2.0 * 1 - torch.Tensor(alphas_cumprod)) | |
| ) | |
| elif self.model_output_type == 'v_prediction': | |
| lvlb_weights = torch.ones_like( | |
| self.betas**2 | |
| / ( | |
| 2 | |
| * self.posterior_variance | |
| * torch.tensor(alphas, dtype=torch.float32) | |
| * (1 - self.alphas_cumprod) | |
| ) | |
| ) | |
| else: | |
| raise NotImplementedError("mu not supported") | |
| # TODO how to choose this term | |
| lvlb_weights[0] = lvlb_weights[1] | |
| self.register_buffer("lvlb_weights", lvlb_weights, persistent=False) | |
| self.lvlb_weights = self.lvlb_weights | |
| assert not torch.isnan(self.lvlb_weights).all() | |
| def p_losses(self, | |
| x_start:Tensor, | |
| cond:Optional[Union[dict,Tensor]], | |
| is_cond_unpack:bool, | |
| t:Tensor, | |
| noise:Optional[Tensor] = None): | |
| if not self.use_vlb_loss: | |
| return super().p_losses(x_start, cond, is_cond_unpack, t, noise) | |
| noise:Tensor = UtilData.default(noise, lambda: torch.randn_like(x_start)) | |
| x_noisy:Tensor = self.q_sample(x_start=x_start, t=t, noise=noise) | |
| model_output:Tensor = self.apply_model(x_noisy, t, cond, is_cond_unpack) | |
| if self.model_output_type == 'x_start': | |
| target:Tensor = x_start | |
| elif self.model_output_type == 'noise': | |
| target:Tensor = noise | |
| elif self.model_output_type == 'v_prediction': | |
| target:Tensor = self.get_v(x_start, noise, t) | |
| else: | |
| print(f'''model output type is {self.model_output_type}. It should be in [x_start, noise]''') | |
| raise NotImplementedError() | |
| if target.shape != model_output.shape: print(f'warning: target shape({target.shape}) and model shape({model_output.shape}) are different') | |
| loss_dict = dict() | |
| loss_simple:Tensor = self.get_loss(model_output, target, mean=False) | |
| loss_simple = loss_simple.mean(dim = list(range(len(loss_simple.shape)))[1:]) | |
| loss_dict.update({f"loss_simple": loss_simple.mean()}) | |
| logvar_t = self.logvar[t] | |
| loss = loss_simple / torch.exp(logvar_t) + logvar_t | |
| if self.learn_logvar: | |
| loss_dict.update({f"loss_gamma": loss.mean()}) | |
| loss_dict.update({"logvar": self.logvar.data.mean()}) | |
| loss = self.loss_simple_weight * loss.mean() | |
| loss_vlb:Tensor = self.get_loss(model_output, target, mean=False) | |
| loss_vlb = loss_vlb.mean(dim=list(range(len(loss_vlb.shape)))[1:]) | |
| loss_vlb = (self.lvlb_weights[t] * loss_vlb).mean() | |
| loss_dict.update({f"loss_vlb": loss_vlb}) | |
| loss += self.original_elbo_weight * loss_vlb | |
| loss_dict.update({f"loss": loss}) | |
| return loss_dict | |
| def get_loss(self, pred:Tensor, target:Tensor, mean=True) -> Tensor: | |
| if self.loss_func == 'l1': | |
| loss = (target - pred).abs() | |
| if mean: | |
| loss = loss.mean() | |
| elif self.loss_func == F.mse_loss: | |
| if mean: | |
| loss = self.loss_func(target, pred) | |
| else: | |
| loss = self.loss_func(target, pred, reduction='none') | |
| else: | |
| raise NotImplementedError("unknown loss type '{loss_type}'") | |
| return loss | |
| if __name__ == '__main__': | |
| ddpm = DDPMLossVLB(model = lambda x, t: x, model_output_type = 'v_prediction') | |
| ddpm.p_losses(x_start = torch.randn(2,3,64,64), cond = None, is_cond_unpack = False, t = torch.tensor([30, 23])) | |
| print('finish') |