| from typing import List, Optional |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
|
|
|
|
def extract(v, t, x_shape):
    """
    Gather per-sample coefficients and reshape them for broadcasting.

    Args:
        v: 1-D tensor of coefficients, indexed along dim 0.
        t: 1-D integer (long) tensor of indices into ``v``, one per batch
            element.
        x_shape: shape of the tensor the result will be multiplied with;
            only its length (number of dimensions) is used.

    Returns:
        A float32 tensor of shape ``[t.shape[0], 1, 1, ...]`` with
        ``len(x_shape)`` dimensions, located on the same device as ``t``.
    """
    # Follow the device of the index tensor instead of hard-coding "cuda",
    # so the helper also works on CPU-only machines and non-default GPUs.
    device = t.device
    out = torch.gather(v.to(device), index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
|
|
|
|
class GaussianDiffusionTrainer(nn.Module):
    """
    Training wrapper that turns a noise-prediction model into a loss.

    A linear beta schedule with ``T`` entries between ``beta_1`` and
    ``beta_T`` is built once; ``forward`` noises a clean batch at a random
    timestep and returns the mean squared error between the model output and
    the noise that was added.
    """

    def __init__(self, beta_1, beta_T, T, model) -> None:
        super().__init__()
        self.model = model

        schedule = torch.linspace(beta_1, beta_T, T).double()
        self.register_buffer('betas', schedule)

        self.T = T
        # Plain attributes (not buffers): per-step alpha and its running
        # product over timesteps.
        self.alphas = 1 - self.betas
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # Scale applied to x_0 and to the noise when forming x_t.
        signal_scale = torch.sqrt(self.beta_alphas)
        self.register_buffer("sqrt_beta_alphas", signal_scale)
        noise_scale = torch.sqrt(1 - self.beta_alphas)
        self.register_buffer("sqrt_one_minus_beta_alphas", noise_scale)

    def forward(self, x_0):
        """
        Return the scalar noise-prediction MSE loss for a clean batch.

        Args:
            x_0: batch of clean samples; dim 0 is the batch dimension.
        """
        batch = x_0.shape[0]
        # One timestep in [0, T) per sample, drawn on x_0's device.
        t = torch.randint(self.T, size=(batch,), device=x_0.device)
        noise = torch.randn_like(x_0)

        keep = extract(self.sqrt_beta_alphas, t, x_0.shape)
        spread = extract(self.sqrt_one_minus_beta_alphas, t, x_0.shape)
        x_t = keep * x_0 + spread * noise

        predicted = self.model(x_t, t)
        return F.mse_loss(predicted, noise, reduction='mean')
| |
|
|
class GaussianDiffusionSampler(nn.Module):
    """
    Step-by-step reverse-process sampler for a noise-prediction model.

    A linear beta schedule with ``T`` entries between ``beta_1`` and
    ``beta_t`` is built in the constructor together with the coefficients
    needed to turn a predicted noise into the mean of the previous step.
    ``forward`` walks the timesteps from ``T - 1`` down to ``0``.

    Args:
        beta_1: first value of the linear beta schedule.
        beta_t: last value of the linear beta schedule.
        model: callable invoked as ``model(x_t, t)`` that returns the
            predicted noise.
        T: number of diffusion timesteps.
    """

    def __init__(self, beta_1, beta_t, model, T) -> None:
        super().__init__()
        self.model = model
        self.T = T
        self.register_buffer(
            "betas",
            torch.linspace(beta_1, beta_t, self.T).double()
        )
        self.alphas = 1 - self.betas
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)

        # Shift the cumulative product right by one: pad a single 1 at the
        # front and keep the first T entries, so entry i holds the cumulative
        # product of step i - 1 (and 1 for step 0).
        self.beta_alphas_prev = F.pad(self.beta_alphas, [1, 0], value=1)[:T]

        # mean = coeff1 * x_t - coeff2 * eps
        self.register_buffer(
            "coeff1",
            (1 / torch.sqrt(self.alphas))
        )
        self.register_buffer(
            "coeff2",
            self.coeff1 * ((1 - self.alphas) / (torch.sqrt(1 - self.beta_alphas)))
        )
        self.register_buffer(
            "posterior_coeff",
            (1 - self.beta_alphas_prev) / (1 - self.beta_alphas) * self.betas
        )

    def _sampling_device(self, x):
        """Return the device of the model's parameters, or ``x.device`` if it has none."""
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return x.device

    def pred_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Return ``coeff1[t] * x_t - coeff2[t] * eps``, broadcast per sample."""
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

    def p_mean_variance(self, x_t, t):
        """
        Return the mean and variance used for the step from ``t`` to ``t - 1``.

        The variance table is ``betas`` with its first entry replaced by
        ``posterior_coeff[1]``.
        """
        var = torch.cat([self.posterior_coeff[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.pred_xt_prev_mean_from_eps(x_t, t, eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        """
        Run the full reverse process starting from ``x_T``.

        Args:
            x_T: starting batch; dim 0 is the batch dimension.

        Returns:
            The final sample, clipped to ``[-1, 1]``.
        """
        # Run on the device the model lives on rather than a hard-coded
        # "cuda", so sampling also works on CPU and on non-default GPUs.
        device = self._sampling_device(x_T)
        x_t = x_T.to(device)
        for timestep in reversed(range(self.T)):
            print(f"Sampling timestep: {timestep}")

            t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * timestep
            mean, var = self.p_mean_variance(x_t, t)
            mean, var = mean.to(device), var.to(device)
            # No noise is added on the final step (timestep 0).
            if timestep > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise

        x_0 = x_t
        return torch.clip(x_0, -1, 1)
|
|
|
|
class DDIMSampler(nn.Module):
    """
    DDIM-style sampler that visits only a subset of the diffusion timesteps.

    Args:
        model: callable invoked as ``model(x, t, c)`` that returns the
            predicted noise.
        n_steps: number of sampling steps to take.
        beta_1: first value of the linear beta schedule.
        beta_T: last value of the linear beta schedule.
        ddim_discretize: ``"uniform"`` or ``"quad"`` spacing of the visited
            timesteps; anything else raises ``NotImplementedError``.
        ddim_eta: scale of the per-step noise; ``0`` makes every sigma zero.
        T: length of the beta schedule. Defaults to ``n_steps``, which
            reproduces the schedule length of the original constructor.

    Raises:
        ValueError: if ``n_steps`` is not positive or ``T`` is smaller than
            ``n_steps``.
        NotImplementedError: for an unknown ``ddim_discretize``.
    """

    def __init__(self, model, n_steps, beta_1, beta_T, ddim_discretize: str = "uniform",
                 ddim_eta=0.1, T: Optional[int] = None) -> None:
        super().__init__()
        self.steps = n_steps
        self.model = model
        self.T = n_steps if T is None else T
        if n_steps < 1 or self.T < n_steps:
            raise ValueError(
                f"need 1 <= n_steps <= T, got n_steps={n_steps}, T={self.T}"
            )

        if ddim_discretize == "uniform":
            # Stride through the full schedule; the previous code divided
            # n_steps by itself, so the stride was always 1.
            c = self.T // n_steps
            time_steps = torch.arange(0, self.T, c, dtype=torch.long) + 1
            # The +1 offset can reach T, which is one past the last valid
            # index of the schedule, so clamp to T - 1.
            self.time_steps = time_steps.clamp(max=self.T - 1)
            print(f"Discreatization uniform : {self.time_steps}")

        elif ddim_discretize == "quad":
            # torch.sqrt does not accept a Python float, so take the root
            # with plain arithmetic.
            upper = (self.T * .8) ** 0.5
            time_steps = (torch.linspace(0, upper, n_steps) ** 2).type(torch.long) + 1
            self.time_steps = time_steps.clamp(max=self.T - 1)
            print(f"Quad descreatization : {self.time_steps}")
        else:
            raise NotImplementedError(ddim_discretize)

        self.register_buffer(
            "betas",
            torch.linspace(beta_1, beta_T, self.T).double()
        )

        self.alphas = 1 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)

        with torch.no_grad():
            self.ddim_alpha = self.alpha_bar[self.time_steps].clone().to(torch.float32)
            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
            # Entry i is the cumulative alpha of the previously visited
            # timestep; the first visited step uses alpha_bar[0].
            self.ddim_alpha_prev = torch.cat(
                [self.alpha_bar[0:1], self.alpha_bar[self.time_steps[:-1]]]
            ).to(torch.float32)
            self.ddim_sigma = (ddim_eta * (
                (1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                (1 - self.ddim_alpha / self.ddim_alpha_prev)
            ) ** .5)
            self.ddim_sqrt_one_minus_alpha = torch.sqrt(1 - self.ddim_alpha)

    def _model_device(self):
        """Return ``model.device`` if present, else its parameters' device, else the betas' device."""
        device = getattr(self.model, "device", None)
        if device is not None:
            return device
        try:
            return next(self.model.parameters()).device
        except (AttributeError, StopIteration):
            return self.betas.device

    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
        """
        Return the model's noise prediction, optionally with guidance.

        When ``uncond_cond`` is ``None`` or ``uncond_scale`` is ``1`` the
        model is called once with ``c``. Otherwise the batch is doubled, the
        model is called on ``[uncond_cond, c]`` and the result is
        ``e_uncond + uncond_scale * (e_cond - e_uncond)``.
        """
        if uncond_cond is None or uncond_scale == 1.:
            return self.model(x, t, c)

        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([uncond_cond, c])

        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
        return e_t

    @torch.no_grad()
    def forward(self,
                shape: List[int],
                cond: torch.Tensor,
                repeat_noise: bool = False,
                temperature: float = 1.,
                x_last: Optional[torch.Tensor] = None,
                uncond_scale: float = 1.,
                uncond_cond: Optional[torch.Tensor] = None,
                skip_steps: int = 0,
                ):
        """
        Run the sampling loop and return the final tensor.

        Args:
            shape: shape of the sample to generate; ``shape[0]`` is the
                batch size.
            cond: conditioning passed to the model as its third argument.
            repeat_noise: use one noise sample shared by the whole batch.
            temperature: multiplier applied to the per-step noise.
            x_last: starting tensor; random normal noise is drawn when it
                is ``None``.
            uncond_scale: guidance scale forwarded to ``get_eps``.
            uncond_cond: unconditional conditioning forwarded to ``get_eps``.
            skip_steps: number of leading (largest) timesteps to skip.
        """
        device = self._model_device()
        batch_size = shape[0]

        x = x_last if x_last is not None else torch.randn(shape, device=device)
        # Largest timestep first; the flipped schedule is iterated directly
        # (calling range() on a tensor is a TypeError).
        time_steps = torch.flip(self.time_steps, [0])[skip_steps:]
        for i, step in enumerate(time_steps.tolist()):
            # Position of this timestep in the unflipped coefficient tables.
            index = len(time_steps) - i - 1
            ts = x.new_full((batch_size,), step, dtype=torch.long)
            x, pred_x0, e_t = self.p_sample(x=x, c=cond, t=ts, step=step, index=index,
                                            repeat_noise=repeat_noise, temperature=temperature,
                                            uncond_scale=uncond_scale,
                                            uncond_cond=uncond_cond)
        return x

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step, index,
                 repeat_noise: bool = False, temperature: float = 1.,
                 uncond_scale: float = 1.,
                 uncond_cond: Optional[torch.Tensor] = None):
        """Take one sampling step; return ``(x_prev, pred_x0, e_t)``."""
        e_t = self.get_eps(x, t, c, uncond_cond=uncond_cond, uncond_scale=uncond_scale)

        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
                                                      temperature, repeat_noise)

        return x_prev, pred_x0, e_t

    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor,
                               temperature: float, repeat_noise: bool):
        """
        Compute the previous-step sample and the implied clean sample.

        Args:
            e_t: predicted noise for ``x``.
            index: position in the precomputed ``ddim_*`` tables.
            x: current sample.
            temperature: multiplier applied to the noise.
            repeat_noise: draw one noise sample of batch size 1 and let it
                broadcast over the batch.

        Returns:
            ``(x_prev, pred_x0)``.
        """
        alpha = self.ddim_alpha[index]
        alpha_prev = self.ddim_alpha_prev[index]
        sigma = self.ddim_sigma[index]
        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]

        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t

        # One chain: previously the zero-noise branch was immediately
        # overwritten by the following independent if/else.
        if sigma == 0.:
            noise = torch.zeros_like(x)
        elif repeat_noise:
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        else:
            noise = torch.randn(x.shape, device=x.device)

        noise = noise * temperature

        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise

        return x_prev, pred_x0