# Diffusion-DDIM / diffusion.py
# Author: Yash Nagraj (commit 3c19795: "Add p_sample function")
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def extract(v, t, x_shape):
    """
    Extract coefficients from ``v`` at the timesteps given by ``t`` and reshape
    the result to [batch_size, 1, 1, ...] so it broadcasts against a tensor of
    shape ``x_shape``.

    Args:
        v: 1-D tensor of per-timestep coefficients (any dtype/device).
        t: long tensor of shape [batch_size] holding timestep indices.
        x_shape: shape of the tensor the output will be multiplied with; only
            its length is used.

    Returns:
        Float32 tensor of shape [batch_size, 1, ..., 1] on ``t``'s device.
    """
    # BUG FIX: the original forced both ``v`` and the result onto "cuda",
    # which crashed on CPU-only machines and produced device mismatches for
    # callers holding CPU tensors. Following ``t.device`` is identical for
    # CUDA inputs and makes the helper device-agnostic.
    out = torch.gather(v.to(t.device), index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
class GaussianDiffusionTrainer(nn.Module):
    """
    Computes the DDPM training loss (Algorithm 1 of Ho et al., 2020).

    For each image a timestep is drawn uniformly, the image is noised with the
    closed-form forward process q(x_t | x_0), and the loss is the MSE between
    the model's noise prediction and the noise actually used.
    """

    def __init__(self, beta_1, beta_T, T, model) -> None:
        super().__init__()
        self.model = model
        self.T = T
        # Linear beta schedule in double precision for a stable cumprod.
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        self.alphas = 1 - self.betas
        # Note: "beta_alphas" is what the literature calls alpha_bar (cumulative
        # product of the alphas).
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)
        # Coefficients of q(x_t | x_0):
        #   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        self.register_buffer("sqrt_beta_alphas", torch.sqrt(self.beta_alphas))
        self.register_buffer("sqrt_one_minus_beta_alphas", torch.sqrt(1 - self.beta_alphas))

    def forward(self, x_0):
        """Return the mean MSE between predicted and true noise for a batch x_0."""
        batch = x_0.shape[0]
        t = torch.randint(self.T, size=(batch,), device=x_0.device)
        eps = torch.randn_like(x_0)
        mean_coeff = extract(self.sqrt_beta_alphas, t, x_0.shape)
        std_coeff = extract(self.sqrt_one_minus_beta_alphas, t, x_0.shape)
        x_t = mean_coeff * x_0 + std_coeff * eps
        return F.mse_loss(self.model(x_t, t), eps, reduction='mean')
class GaussianDiffusionSampler(nn.Module):
    """
    Ancestral DDPM sampler: starting from pure noise x_T, repeatedly samples
    p(x_{t-1} | x_t) for t = T-1 down to 0 and returns the final image clipped
    to [-1, 1].

    NOTE(review): the device is hardcoded to "cuda" in forward(), so this
    sampler only runs on CUDA machines — consider deriving the device from the
    input tensor instead; verify against how callers place the model.
    """

    def __init__(self, beta_1, beta_t, model, T) -> None:
        super().__init__()
        self.model = model
        self.T = T
        self.register_buffer("betas", torch.linspace(beta_1, beta_t, self.T).double())
        self.alphas = 1 - self.betas
        # "beta_alphas" is alpha_bar (cumulative product of the alphas).
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)
        # alpha_bar_{t-1}: shift right by one, defining alpha_bar_{-1} = 1.
        self.beta_alphas_prev = F.pad(self.beta_alphas, [1, 0], value=1)[:T]
        # Posterior-mean coefficients:
        #   mu = coeff1 * x_t - coeff2 * eps_theta(x_t, t)
        self.register_buffer("coeff1", 1 / torch.sqrt(self.alphas))
        self.register_buffer("coeff2", self.coeff1 * ((1 - self.alphas) / torch.sqrt(1 - self.beta_alphas)))
        # Variance of q(x_{t-1} | x_t, x_0).
        self.register_buffer(
            "posterior_coeff",
            (1 - self.beta_alphas_prev) / (1 - self.beta_alphas) * self.betas,
        )

    def pred_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Posterior mean mu_theta(x_t, t) from the predicted noise eps."""
        c1 = extract(self.coeff1, t, x_t.shape)
        c2 = extract(self.coeff2, t, x_t.shape)
        return c1 * x_t - c2 * eps

    def p_mean_variance(self, x_t, t):
        """Return (mean, variance) of p(x_{t-1} | x_t)."""
        # Use the (smaller) posterior variance at t == 1, beta_t elsewhere.
        model_var = torch.cat([self.posterior_coeff[1:2], self.betas[1:]])
        model_var = extract(model_var, t, x_t.shape)
        predicted_eps = self.model(x_t, t)
        mean = self.pred_xt_prev_mean_from_eps(x_t, t, predicted_eps)
        return mean, model_var

    def forward(self, x_T):
        """Run the full reverse process from x_T and return x_0 in [-1, 1]."""
        x_t = x_T.to("cuda")
        for timestep in reversed(range(self.T)):
            print(f"Sampling timestep: {timestep}")
            t = x_t.new_ones([x_t.shape[0],], dtype=torch.long) * timestep
            mean, var = self.p_mean_variance(x_t, t)
            mean, var = mean.to("cuda"), var.to("cuda")
            # No noise on the final step: x_0 is the posterior mean itself.
            noise = torch.randn_like(x_t).to("cuda") if timestep > 0 else 0
            x_t = mean + torch.sqrt(var) * noise
        return torch.clip(x_t, -1, 1)
class DDIMSampler(nn.Module):
    """
    Denoising Diffusion Implicit Models (DDIM) sampler.

    Runs the DDIM reverse process over a (possibly strided) subsequence of the
    training timesteps. ``ddim_eta`` interpolates between fully deterministic
    DDIM sampling (0.0) and DDPM-like stochastic sampling (1.0).
    """

    def __init__(self, model, n_steps, beta_1, beta_T,
                 ddim_discretize: str = "uniform", ddim_eta=0.1,
                 total_steps: Optional[int] = None) -> None:
        """
        Args:
            model: noise-prediction network, called as model(x, t, cond);
                must expose a ``device`` attribute.
            n_steps: number of DDIM sampling steps.
            beta_1, beta_T: endpoints of the linear beta schedule.
            ddim_discretize: "uniform" or "quad" timestep selection.
            ddim_eta: stochasticity of the sampler (0 = deterministic).
            total_steps: length T of the underlying training diffusion.
                Defaults to ``n_steps`` (i.e. sample on every timestep),
                which keeps the original constructor signature working.
        """
        super().__init__()
        self.model = model
        self.steps = n_steps if total_steps is None else total_steps
        if ddim_discretize == "uniform":
            # Evenly strided subsequence of [0, steps). BUG FIX: the original
            # added +1 to every index, so the largest lookup was
            # alpha_bar[steps] — out of range whenever the stride was 1
            # (and steps == n_steps made the stride always 1).
            c = max(self.steps // n_steps, 1)
            self.time_steps = torch.arange(0, self.steps, c, dtype=torch.long)
            print(f"Discreatization uniform : {self.time_steps}")
        elif ddim_discretize == "quad":
            # Quadratically spaced timesteps. BUG FIX: torch.sqrt() requires a
            # tensor (the original passed a Python float); use a float exponent
            # instead, and clamp into the valid index range.
            self.time_steps = (
                torch.linspace(0, (self.steps * .8) ** 0.5, n_steps) ** 2
            ).long().clamp_(max=self.steps - 1)
            print(f"Quad descreatization : {self.time_steps}")
        else:
            raise NotImplementedError(ddim_discretize)
        self.register_buffer(
            "betas",
            torch.linspace(beta_1, beta_T, self.steps).double()
        )
        self.alphas = 1 - self.betas
        self.alpha_bar = torch.cumprod(self.alphas, dim=0)
        with torch.no_grad():
            self.ddim_alpha = self.alpha_bar[self.time_steps].clone().to(torch.float32)
            self.ddim_alpha_sqrt = torch.sqrt(self.ddim_alpha)
            # alpha_bar of the *previous* kept timestep, with alpha_bar[0]
            # standing in for the step before the first one. BUG FIX: the
            # original wrote `self.alpha_bar[self.time_steps:-1]`, which uses a
            # tensor as a slice bound (invalid); slicing the index tensor
            # itself was intended.
            self.ddim_alpha_prev = torch.cat(
                [self.alpha_bar[0:1], self.alpha_bar[self.time_steps[:-1]]]
            ).to(torch.float32)
            # Eq. (16) of the DDIM paper: sigma controls stochasticity.
            self.ddim_sigma = (ddim_eta * (
                (1 - self.ddim_alpha_prev) / (1 - self.ddim_alpha) *
                (1 - self.ddim_alpha / self.ddim_alpha_prev)
            ) ** .5)
            self.ddim_sqrt_one_minus_alpha = torch.sqrt(1 - self.ddim_alpha)

    def get_eps(self, x: torch.Tensor, t: torch.Tensor, c: torch.Tensor, *,
                uncond_scale: float, uncond_cond: Optional[torch.Tensor]):
        """
        Predict noise for x at timestep t, with optional classifier-free
        guidance: eps = eps_uncond + scale * (eps_cond - eps_uncond).
        """
        if uncond_cond is None or uncond_scale == 1.:
            return self.model(x, t, c)
        # One batched forward pass for the conditional and unconditional halves.
        x_in = torch.cat([x] * 2)
        t_in = torch.cat([t] * 2)
        c_in = torch.cat([uncond_cond, c])
        e_t_uncond, e_t_cond = self.model(x_in, t_in, c_in).chunk(2)
        e_t = e_t_uncond + uncond_scale * (e_t_cond - e_t_uncond)
        return e_t

    @torch.no_grad()
    def forward(self,
                shape: List[int],
                cond: torch.Tensor,
                repeat_noise: bool = False,
                temperature: float = 1.,
                x_last: Optional[torch.Tensor] = None,
                uncond_scale: float = 1.,
                uncond_cond: Optional[torch.Tensor] = None,
                skip_steps: int = 0,
                ):
        """
        Sample a batch of shape ``shape``, conditioned on ``cond``, by walking
        the selected timesteps from largest to smallest. Returns x_0.
        """
        device = self.model.device
        batch_size = shape[0]
        # Start from pure noise (x_T) unless a partially denoised latent is given.
        x = x_last if x_last is not None else torch.randn(shape, device=device)
        time_steps = torch.flip(self.time_steps, [0])[skip_steps:]
        # BUG FIX: the original looped `reversed(range(time_steps))`, but
        # range() cannot take a tensor — and time_steps is already reversed by
        # the flip above, so we iterate it directly.
        for i, step in enumerate(time_steps):
            index = len(time_steps) - i - 1
            ts = x.new_full((batch_size,), int(step), dtype=torch.long)
            x, pred_x0, e_t = self.p_sample(
                x=x, c=cond, t=ts, step=int(step), index=index,
                repeat_noise=repeat_noise, temperature=temperature,
                uncond_scale=uncond_scale, uncond_cond=uncond_cond)
        return x  # x_0

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, c: torch.Tensor, t: torch.Tensor, step, index,
                 repeat_noise: bool = False, temperature: float = 1.,
                 uncond_scale: float = 1.,
                 uncond_cond: Optional[torch.Tensor] = None):
        """One reverse step: returns (x_{t_prev}, predicted x_0, predicted eps)."""
        e_t = self.get_eps(x, t, c, uncond_cond=uncond_cond, uncond_scale=uncond_scale)
        x_prev, pred_x0 = self.get_x_prev_and_pred_x0(e_t, index, x,
                                                      temperature, repeat_noise)
        return x_prev, pred_x0, e_t

    def get_x_prev_and_pred_x0(self, e_t: torch.Tensor, index: int, x: torch.Tensor,
                               temperature: float, repeat_noise: bool):
        """DDIM update (Eq. 12 of the paper): returns (x_{t_prev}, pred x_0)."""
        alpha = self.ddim_alpha[index]
        alpha_prev = self.ddim_alpha_prev[index]
        sigma = self.ddim_sigma[index]
        sqrt_one_minus_alpha = self.ddim_sqrt_one_minus_alpha[index]
        # Predicted x_0 from the noise estimate.
        pred_x0 = (x - sqrt_one_minus_alpha * e_t) / (alpha ** 0.5)
        # "Direction pointing to x_t".
        dir_xt = (1. - alpha_prev - sigma ** 2).sqrt() * e_t
        # BUG FIX: the original used two independent `if`s, so the zero noise
        # chosen for sigma == 0 was immediately overwritten by fresh Gaussian
        # noise. The three cases are mutually exclusive.
        if sigma == 0.:
            noise = torch.zeros_like(x)
        elif repeat_noise:
            # A single noise sample shared across the batch.
            noise = torch.randn((1, *x.shape[1:]), device=x.device)
        else:
            noise = torch.randn(x.shape, device=x.device)
        noise = noise * temperature
        x_prev = (alpha_prev ** 0.5) * pred_x0 + dir_xt + sigma * noise
        return x_prev, pred_x0