import math
from typing import List, Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.

    Args:
        v: 1-D tensor of per-timestep coefficients, indexed along dim 0.
        t: integer (long) tensor of timestep indices, one per batch element.
        x_shape: shape of the tensor the result will be broadcast against;
            only its length is used.

    Returns:
        float32 tensor of shape ``[t.shape[0], 1, ..., 1]`` on ``t``'s device.
    """
    # Follow the device of the indices instead of hard-coding CUDA, so this
    # helper also works on CPU-only machines.
    v = v.to(t.device)
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


def _module_device(module: nn.Module, fallback: torch.device) -> torch.device:
    """Return the device ``module`` computes on, or ``fallback`` if unknown.

    Prefers an explicit ``device`` attribute (some model wrappers expose one),
    then the device of the module's first parameter, then its first buffer.
    """
    device = getattr(module, "device", None)
    if device is not None:
        return torch.device(device)
    tensor = next(module.parameters(), None)
    if tensor is None:
        tensor = next(module.buffers(), None)
    return fallback if tensor is None else tensor.device


class GaussianDiffusionTrainer(nn.Module):
    """Noise-prediction training loss for a Gaussian diffusion model.

    For each example a timestep ``t`` is drawn uniformly from ``[0, T)``.
    ``x_0`` is noised in closed form as
    ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise``, and the loss is
    the MSE between ``model(x_t, t)`` and ``noise``.
    """

    def __init__(self, beta_1, beta_T, T, model) -> None:
        super().__init__()
        self.model = model
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double()
        )
        self.T = T
        self.alphas = 1 - self.betas
        # Cumulative product of the alphas (conventionally written alpha_bar).
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # Calculation for Algorithm 1 := sqrt(alpha_bar), sqrt(1-alpha_bar)
        self.register_buffer(
            "sqrt_beta_alphas", torch.sqrt(self.beta_alphas)
        )
        self.register_buffer(
            "sqrt_one_minus_beta_alphas", torch.sqrt(1 - self.beta_alphas)
        )

    def forward(self, x_0):
        """Return the scalar (mean-reduced) MSE loss for the batch ``x_0``."""
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_beta_alphas, t, x_0.shape) * x_0
            + extract(self.sqrt_one_minus_beta_alphas, t, x_0.shape) * noise
        )
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='mean')
        return loss


class GaussianDiffusionSampler(nn.Module):
    """Ancestral sampler that runs all ``T`` reverse diffusion steps.

    ``model(x_t, t)`` is expected to predict the noise added at step ``t``.
    """

    def __init__(self, beta_1, beta_t, model, T) -> None:
        super().__init__()
        self.model = model
        self.T = T
        self.register_buffer(
            "betas", torch.linspace(beta_1, beta_t, self.T).double()
        )
        self.alphas = 1 - self.betas
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # Shift alpha_bar right by one: prepend 1 and drop the last entry, so
        # index t holds alpha_bar_{t-1} (with alpha_bar_{-1} := 1).
        self.beta_alphas_prev = F.pad(self.beta_alphas, [1, 0], value=1)[:T]

        # Mean of x_{t-1} given the predicted noise:
        #   mean = coeff1 * x_t - coeff2 * eps
        self.register_buffer(
            "coeff1", (1 / torch.sqrt(self.alphas))
        )
        self.register_buffer(
            "coeff2",
            self.coeff1 * ((1 - self.alphas) / (torch.sqrt(1 - self.beta_alphas))),
        )
        # (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t
        self.register_buffer(
            "posterior_coeff",
            (1 - self.beta_alphas_prev) / (1 - self.beta_alphas) * self.betas,
        )

    def pred_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Return the mean of x_{t-1} from ``x_t`` and the predicted noise."""
        return (
            extract(self.coeff1, t, x_t.shape) * x_t
            - extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """Return ``(mean, variance)`` of the reverse step at timesteps ``t``.

        The variance uses ``posterior_coeff[1]`` at t == 0 and ``betas[t]``
        elsewhere.
        """
        var = torch.cat([self.posterior_coeff[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.pred_xt_prev_mean_from_eps(x_t, t, eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        """Denoise ``x_T`` through all ``T`` steps and return x_0 clipped to [-1, 1].

        Sampling runs on the model's device (falling back to ``x_T``'s device
        when the model has no parameters or buffers), not a hard-coded GPU.
        """
        device = _module_device(self.model, x_T.device)
        x_t = x_T.to(device)
        for timestep in reversed(range(self.T)):
            print(f"Sampling timestep: {timestep}")
            t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * timestep
            mean, var = self.p_mean_variance(x_t, t)
            # No noise is injected on the final step (t == 0).
            noise = torch.randn_like(x_t) if timestep > 0 else 0
            x_t = mean + torch.sqrt(var) * noise
        x_0 = x_t
        return torch.clip(x_0, -1, 1)


class DDIMSampler(nn.Module):
    """DDIM sampler over a subsequence of the diffusion timesteps.

    Args:
        model: callable as ``model(x, t, c)`` returning the predicted noise.
        n_steps: number of DDIM sampling steps.
        beta_1, beta_T: endpoints of the linear beta schedule.
        ddim_discretize: ``"uniform"`` or ``"quad"`` spacing of the sampled
            timesteps.
        ddim_eta: stochasticity; 0 gives deterministic DDIM sampling.
        T: length of the beta schedule (number of training timesteps).
            Defaults to ``n_steps``, the only length the original signature
            could express.
    """

    def __init__(self, model, n_steps, beta_1, beta_T, ddim_discretize: str = "uniform",
                 ddim_eta=0.1, T: Optional[int] = None) -> None:
        super().__init__()
        if T is None:
            T = n_steps
        if not 1 <= n_steps <= T:
            raise ValueError(f"n_steps must be in [1, T={T}], got {n_steps}")
        # `steps` is the beta-schedule length; `n_steps` the DDIM step count.
        self.steps = T
        self.n_steps = n_steps
        self.model = model
        # The +1 offset follows the reference DDIM implementation. Clamp so it
        # never indexes past the end of the schedule (e.g. when every timestep
        # is used and the last one would become T).
        last = self.steps - 1
        if ddim_discretize == "uniform":
            c = self.steps // n_steps
            self.time_steps = (torch.arange(0, self.steps, c) + 1).clamp(max=last)
            print(f"Discreatization uniform : {self.time_steps}")
        elif ddim_discretize == "quad":
            self.time_steps = (
                (torch.linspace(0, math.sqrt(self.steps * .8), n_steps) ** 2).long() + 1
            ).clamp(max=last)
            print(f"Quad descreatization : {self.time_steps}")
        else:
            raise NotImplementedError(ddim_discretize)

        self.register_buffer(
            "betas", torch.linspace(beta_1, beta_T, self.steps).double()
        )
        self.alphas = 1 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)

        with torch.no_grad():
            self.ddim_alpha = self.alpha_bar[self.time_steps].clone().to(torch.float32)
            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
            # alpha_bar at the previous (smaller) sampled timestep; the first
            # entry reuses alpha_bar[0].
            self.ddim_alpha_prev = torch.cat(
                [self.alpha_bar[0:1], self.alpha_bar[self.time_steps[:-1]]]
            )
            # Clamp the radicand at 0: repeated timesteps plus float32/float64
            # rounding can otherwise make it a tiny negative number (NaN sigma).
            radicand = (
                (1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha)
                * (1 - self.ddim_alpha / self.ddim_alpha_prev)
            ).clamp(min=0)
            self.ddim_sigma = ddim_eta * radicand ** .5
            self.ddim_sqrt_one_minus_alpha = torch.sqrt(1 - self.ddim_alpha)

    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
        """Predict noise, applying classifier-free guidance when configured.

        With ``uncond_cond`` given and ``uncond_scale != 1`` the model is run on
        a doubled batch (unconditional + conditional), and the two predictions
        are mixed as ``uncond + scale * (cond - uncond)``.
        """
        if uncond_cond is None or uncond_scale == 1.:
            return self.model(x, t, c)

        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([uncond_cond, c])
        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
        return e_t

    @torch.no_grad()
    def forward(self, shape: List[int], cond: torch.Tensor, repeat_noise: bool = False,
                temperature: float = 1., x_last: Optional[torch.Tensor] = None,
                uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None,
                skip_steps: int = 0,
                ):
        """Run DDIM sampling and return the final sample.

        Args:
            shape: output shape; ``shape[0]`` is the batch size.
            cond: conditioning passed to the model.
            x_last: optional starting sample; Gaussian noise otherwise.
            skip_steps: number of leading (noisiest) DDIM steps to skip.
        """
        if x_last is not None:
            device = _module_device(self.model, x_last.device)
            x = x_last.to(device)
        else:
            device = _module_device(self.model, torch.device("cpu"))
            x = torch.randn(shape, device=device)
        batch_size = shape[0]

        # Iterate timesteps from noisiest to cleanest; `index` addresses the
        # precomputed ddim_* arrays, which are ordered by increasing timestep.
        time_steps = torch.flip(self.time_steps, [0])[skip_steps:]
        n_remaining = len(time_steps)
        for i, step in enumerate(time_steps.tolist()):
            index = n_remaining - i - 1
            ts = x.new_full((batch_size,), step, dtype=torch.long)
            x, pred_x0, e_t = self.p_sample(x=x, c=cond, t=ts, step=step, index=index,
                                            repeat_noise=repeat_noise, temperature=temperature,
                                            uncond_scale=uncond_scale,
                                            uncond_cond=uncond_cond)  # type: ignore
        return x  # Return x0

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step, index,
                 repeat_noise: bool = False, temperature: float = 1.,
                 uncond_scale: float = 1., uncond_cond: Optional[torch.Tensor] = None):
        """Take one DDIM step; return ``(x_prev, pred_x0, e_t)``."""
        e_t = self.get_eps(x, t, c, uncond_cond=uncond_cond, uncond_scale=uncond_scale)
        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x, temperature, repeat_noise)
        return x_prev, pred_x0, e_t

    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor,
                               temperature: float, repeat_noise: bool):
        """Compute the DDIM update for position ``index`` of the schedule.

        Returns ``(x_prev, pred_x0)``. No random numbers are drawn when
        ``sigma == 0``. With ``repeat_noise`` one noise sample is shared across
        the batch.
        """
        alpha = self.ddim_alpha[index]
        alpha_prev = self.ddim_alpha_prev[index]
        sigma = self.ddim_sigma[index]
        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
        # Clamp guards against a tiny negative value from rounding.
        dir_xt = (1. - alpha_prev - sigma ** 2).clamp(min=0).sqrt() * e_t

        if sigma == 0.:
            noise = torch.zeros_like(x)
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        else:
            noise = torch.randn(x.shape, device=x.device)
        noise = noise * temperature

        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise
        return x_prev, pred_x0