| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
|
|
def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.

    Args:
        v: 1-D tensor of coefficients, indexed along dim 0.
        t: 1-D integer (long) tensor of indices into ``v``; its length is
            used as the batch size of the result.
        x_shape: shape of the tensor the result will be broadcast against;
            only its number of dimensions is used.

    Returns:
        A float32 tensor of shape ``[len(t), 1, ..., 1]`` with
        ``len(x_shape)`` dimensions, located on the same device as ``t``.
    """
    # Follow the index tensor's device instead of a hard-coded "cuda" so the
    # helper also works on CPU-only hosts and on other accelerators.
    device = t.device
    out = torch.gather(v.to(device), index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
|
|
|
|
class GaussianDiffusionTrainer(nn.Module):
    """
    Training wrapper that noises a clean batch at random timesteps and
    returns the mean squared error between the wrapped model's output and
    the noise that was added.
    """

    def __init__(self, beta_1, beta_T, T, model) -> None:
        """
        Args:
            beta_1: first value of the linearly spaced ``betas`` schedule.
            beta_T: last value of the linearly spaced ``betas`` schedule.
            T: number of timesteps (length of the schedule).
            model: callable invoked as ``model(x_t, t)``.
        """
        super().__init__()
        self.model = model
        self.T = T

        # Linear schedule kept in float64; registered as a buffer.
        schedule = torch.linspace(beta_1, beta_T, T).double()
        self.register_buffer('betas', schedule)

        # Plain (non-buffer) attributes: per-step alphas and their running product.
        self.alphas = 1 - self.betas
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # Coefficients applied to x_0 and to the noise when building x_t.
        for buffer_name, values in (
            ("sqrt_beta_alphas", torch.sqrt(self.beta_alphas)),
            ("sqrt_one_minus_beta_alphas", torch.sqrt(1 - self.beta_alphas)),
        ):
            self.register_buffer(buffer_name, values)

    def forward(self, x_0):
        """
        Draw one random timestep per sample, build the noised input and
        return the scalar MSE between ``model(x_t, t)`` and the noise.
        """
        batch_size = x_0.shape[0]
        t = torch.randint(self.T, size=(batch_size,), device=x_0.device)
        noise = torch.randn_like(x_0)

        signal_scale = extract(self.sqrt_beta_alphas, t, x_0.shape)
        noise_scale = extract(self.sqrt_one_minus_beta_alphas, t, x_0.shape)
        x_t = signal_scale * x_0 + noise_scale * noise

        predicted_noise = self.model(x_t, t)
        return F.mse_loss(predicted_noise, noise, reduction='mean')
| |
|
|
class GaussianDiffusionSampler(nn.Module):
    """
    Sampling wrapper that starts from ``x_T`` and iterates from timestep
    ``T - 1`` down to ``0``, each step combining the wrapped model's output
    with precomputed schedule coefficients, then clips the result to
    ``[-1, 1]``.
    """

    def __init__(self, beta_1, beta_t, model, T) -> None:
        """
        Args:
            beta_1: first value of the linearly spaced ``betas`` schedule.
            beta_t: last value of the linearly spaced ``betas`` schedule.
            model: callable invoked as ``model(x_t, t)``.
            T: number of timesteps (length of the schedule).
        """
        super().__init__()
        self.model = model
        self.T = T
        self.register_buffer(
            "betas",
            torch.linspace(beta_1, beta_t, self.T).double()
        )
        self.alphas = 1 - self.betas
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # Pad self.beta_alphas with a single leading 1 and keep the first T
        # entries, i.e. the cumulative product shifted by one step.
        self.beta_alphas_prev = F.pad(self.beta_alphas, [1, 0], value=1)[:T]

        # Multiplier applied to x_t when computing the previous-step mean.
        self.register_buffer(
            "coeff1",
            (1 / torch.sqrt(self.alphas))
        )

        # Multiplier applied to the model output (eps) in the same formula.
        self.register_buffer(
            "coeff2",
            self.coeff1 * ((1 - self.alphas) / (torch.sqrt(1 - self.beta_alphas)))
        )

        # Only entry 1 of this buffer is read by p_mean_variance.
        self.register_buffer(
            "posterior_coeff",
            (1 - self.beta_alphas_prev) / (1 - self.beta_alphas) * self.betas
        )

    def _sampling_device(self):
        """Return the device of the model's first parameter, else the buffers' device."""
        parameters = getattr(self.model, "parameters", None)
        first_param = next(iter(parameters()), None) if callable(parameters) else None
        return self.betas.device if first_param is None else first_param.device

    def pred_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Return ``coeff1[t] * x_t - coeff2[t] * eps`` with per-sample coefficients."""
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """
        Return ``(mean, var)`` for the step from ``t`` to ``t - 1``.

        The variance table is ``betas`` with its first entry replaced by
        ``posterior_coeff[1]``.
        """
        var = torch.cat([self.posterior_coeff[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.pred_xt_prev_mean_from_eps(x_t, t, eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Run the full reverse loop starting from ``x_T``.

        Args:
            x_T: starting tensor; its first dimension is the batch size.

        Returns:
            The final sample clipped to ``[-1, 1]``.
        """
        # Run where the model lives rather than on a hard-coded "cuda", so
        # sampling also works on CPU-only hosts and other accelerators.
        x_t = x_T.to(self._sampling_device())
        for timestep in reversed(range(self.T)):
            print(f"Sampling timestep: {timestep}")

            # One identical timestep index per sample, on x_t's device.
            t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * timestep
            mean, var = self.p_mean_variance(x_t, t)
            # No fresh noise is added on the final step (timestep 0).
            if timestep > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise

        x_0 = x_t
        return torch.clip(x_0, -1, 1)
|
|
|
|