import torch
import torch.nn as nn
import torch.nn.functional as F


def extract(v, t, x_shape):
    """
    Extract coefficients at the specified timesteps t, then reshape to
    [batch_size, 1, 1, 1, ...] so they broadcast against x of shape x_shape.
    """
    # Keep the coefficients on the same device as the timestep indices
    # instead of hard-coding "cuda".
    v = v.to(t.device)
    out = torch.gather(v, index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))


class GaussianDiffusionTrainer(nn.Module):
    def __init__(self, beta_1, beta_T, T, model) -> None:
        super().__init__()
        self.model = model
        self.T = T

        # Linear beta schedule beta_1 ... beta_T.
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double()
        )
        # alpha_t = 1 - beta_t; "beta_alphas" is the cumulative product
        # alpha_bar_t = prod_{s<=t} alpha_s from the DDPM paper.
        self.alphas = 1 - self.betas
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # Coefficients for Algorithm 1: sqrt(alpha_bar), sqrt(1 - alpha_bar).
        self.register_buffer(
            "sqrt_beta_alphas", torch.sqrt(self.beta_alphas)
        )
        self.register_buffer(
            "sqrt_one_minus_beta_alphas", torch.sqrt(1 - self.beta_alphas)
        )

    def forward(self, x_0):
        # Sample a random timestep for each image in the batch.
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        # Forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps.
        x_t = (
            extract(self.sqrt_beta_alphas, t, x_0.shape) * x_0
            + extract(self.sqrt_one_minus_beta_alphas, t, x_0.shape) * noise
        )
        # Train the model to predict the noise that was added.
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='mean')
        return loss


class GaussianDiffusionSampler(nn.Module):
    def __init__(self, beta_1, beta_t, model, T) -> None:
        super().__init__()
        self.model = model
        self.T = T

        self.register_buffer(
            "betas", torch.linspace(beta_1, beta_t, self.T).double()
        )
        self.alphas = 1 - self.betas
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)
        # Pad alpha_bar with a single leading 1 so that
        # beta_alphas_prev[t] = alpha_bar_{t-1} (with alpha_bar_0 := 1),
        # then truncate back to length T.
        self.beta_alphas_prev = F.pad(self.beta_alphas, [1, 0], value=1)[:T]

        # Coefficients of the posterior mean:
        # mu_theta(x_t, t) = coeff1 * x_t - coeff2 * eps_theta(x_t, t).
        self.register_buffer(
            "coeff1", 1 / torch.sqrt(self.alphas)
        )
        self.register_buffer(
            "coeff2",
            self.coeff1 * ((1 - self.alphas) / torch.sqrt(1 - self.beta_alphas))
        )
        # Posterior variance tilde_beta_t = (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t) * beta_t.
        self.register_buffer(
            "posterior_coeff",
            (1 - self.beta_alphas_prev) / (1 - self.beta_alphas) * self.betas
        )

    def pred_xt_prev_mean_from_eps(self, x_t, t, eps):
        # Mean of p(x_{t-1} | x_t) given the predicted noise eps.
        return (
            extract(self.coeff1, t, x_t.shape) * x_t
            - extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        # Use the posterior variance at t = 1 and beta_t for t > 1,
        # as in the DDPM reference implementation.
        var = torch.cat([self.posterior_coeff[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)
        eps = self.model(x_t, t)
        xt_prev_mean = self.pred_xt_prev_mean_from_eps(x_t, t, eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        # Algorithm 2: start from pure noise x_T and denoise step by step.
        x_t = x_T.to("cuda")
        for timestep in reversed(range(self.T)):
            print(f"Sampling timestep: {timestep}")
            t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * timestep
            mean, var = self.p_mean_variance(x_t, t)
            mean, var = mean.to("cuda"), var.to("cuda")
            # No noise is added at the final step (t = 0).
            if timestep > 0:
                noise = torch.randn_like(x_t).to("cuda")
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
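

# --- Usage sketch (illustrative, not part of the original listing) ---
# A minimal example of how the trainer and sampler above might be wired
# together. `SimpleNoisePredictor`, the hyperparameters, and the random data
# are placeholders chosen for the sketch; in practice the model would be a
# UNet that takes (x_t, t) and returns a noise prediction shaped like x_t.
# The sampler hard-codes "cuda", so a GPU is assumed here.
if __name__ == "__main__":
    class SimpleNoisePredictor(nn.Module):
        """Toy stand-in for the UNet: predicts noise from (x_t, t)."""

        def __init__(self, channels=3):
            super().__init__()
            self.net = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        def forward(self, x_t, t):
            # A real model would also condition on an embedding of t.
            return self.net(x_t)

    device = "cuda"
    model = SimpleNoisePredictor().to(device)

    beta_1, beta_T, T = 1e-4, 0.02, 1000
    trainer = GaussianDiffusionTrainer(beta_1, beta_T, T, model).to(device)
    sampler = GaussianDiffusionSampler(beta_1, beta_T, model, T).to(device)

    # One training step on random data standing in for images scaled to [-1, 1].
    x_0 = torch.rand(8, 3, 32, 32, device=device) * 2 - 1
    loss = trainer(x_0)
    loss.backward()
    print(f"training loss: {loss.item():.4f}")

    # Sampling: start from pure Gaussian noise x_T and denoise down to x_0.
    with torch.no_grad():
        x_T = torch.randn(4, 3, 32, 32, device=device)
        samples = sampler(x_T)
    print(samples.shape)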