# Diffusion-CIFAR10 / diffusion.py
# Author: Yash Nagraj
# Commit e9c60a4: "Add the checkpoint and the generated image"
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
def extract(v, t, x_shape):
    """
    Gather per-timestep coefficients from the 1-D tensor *v* at indices *t*,
    then reshape to [batch_size, 1, 1, ...] so the result broadcasts against
    a tensor of shape *x_shape*.

    Args:
        v: 1-D tensor of per-timestep coefficients (length T).
        t: long tensor of timestep indices, shape [batch_size].
        x_shape: shape of the tensor the output will be multiplied with.

    Returns:
        Float tensor of shape [batch_size, 1, ..., 1] (len(x_shape) dims),
        on the same device as *t*.
    """
    # Follow the device of the index tensor instead of hard-coding "cuda":
    # torch.gather requires v and t to share a device, and the callers in
    # this file already place t on the GPU when one is used, so behavior on
    # CUDA setups is unchanged while CPU-only machines no longer crash.
    out = torch.gather(v.to(t.device), index=t, dim=0).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
class GaussianDiffusionTrainer(nn.Module):
    """
    Forward (noising) process of DDPM plus the training loss (Algorithm 1
    of Ho et al., 2020): pick a random timestep per sample, corrupt x_0
    with Gaussian noise in closed form, and regress the model's output
    onto that noise.

    Args:
        beta_1: variance-schedule value at t = 1.
        beta_T: variance-schedule value at t = T.
        T: number of diffusion steps.
        model: noise-prediction network, called as model(x_t, t).
    """

    def __init__(self, beta_1, beta_T, T, model) -> None:
        super().__init__()
        self.model = model
        self.T = T
        # Linear beta schedule, kept in double precision for accuracy.
        self.register_buffer(
            'betas',
            torch.linspace(beta_1, beta_T, T).double()
        )
        self.alphas = 1 - self.betas
        # Cumulative product: alpha_bar_t = prod_{s<=t} alpha_s.
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)
        # The two coefficients of the closed-form forward process
        # q(x_t | x_0) = N(sqrt(alpha_bar)*x_0, (1-alpha_bar)*I).
        self.register_buffer(
            "sqrt_beta_alphas",
            torch.sqrt(self.beta_alphas)
        )
        self.register_buffer(
            "sqrt_one_minus_beta_alphas",
            torch.sqrt(1 - self.beta_alphas)
        )

    def forward(self, x_0):
        """Return the mean MSE between predicted and true noise at random t."""
        batch = x_0.shape[0]
        t = torch.randint(self.T, size=(batch,), device=x_0.device)
        noise = torch.randn_like(x_0)
        # Sample x_t directly from x_0 via the closed-form forward process.
        scale = extract(self.sqrt_beta_alphas, t, x_0.shape)
        sigma = extract(self.sqrt_one_minus_beta_alphas, t, x_0.shape)
        x_t = scale * x_0 + sigma * noise
        return F.mse_loss(self.model(x_t, t), noise, reduction='mean')
class GaussianDiffusionSampler(nn.Module):
    """
    Reverse (denoising) process of DDPM: starting from pure noise x_T,
    repeatedly applies the learned noise predictor to recover x_0.

    Args:
        beta_1: variance-schedule value at t = 1.
        beta_t: variance-schedule value at t = T.
        model: noise-prediction network, called as model(x_t, t).
        T: number of diffusion steps.
    """

    def __init__(self, beta_1, beta_t, model, T) -> None:
        super().__init__()
        self.model = model
        self.T = T
        # Linear beta schedule in double precision.
        self.register_buffer(
            "betas",
            torch.linspace(beta_1, beta_t, self.T).double()
        )
        self.alphas = 1 - self.betas
        self.beta_alphas = torch.cumprod(self.alphas, dim=0)
        # alpha_bar shifted right by one step: pad a leading 1 (alpha_bar_0)
        # and drop the last element, so index t holds alpha_bar_{t-1}.
        self.beta_alphas_prev = F.pad(self.beta_alphas, [1, 0], value=1)[:T]
        # Coefficients of the posterior mean (Eq. 11 in Ho et al.):
        # mean = coeff1 * x_t - coeff2 * eps.
        self.register_buffer(
            "coeff1",
            (1 / torch.sqrt(self.alphas))
        )
        self.register_buffer(
            "coeff2",
            self.coeff1 * ((1 - self.alphas) / (torch.sqrt(1 - self.beta_alphas)))
        )
        # Variance of the forward-process posterior q(x_{t-1} | x_t, x_0).
        self.register_buffer(
            "posterior_coeff",
            (1 - self.beta_alphas_prev) / (1 - self.beta_alphas) * self.betas
        )

    def pred_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Posterior mean of x_{t-1} given x_t and the predicted noise eps."""
        c1 = extract(self.coeff1, t, x_t.shape)
        c2 = extract(self.coeff2, t, x_t.shape)
        return c1 * x_t - c2 * eps

    def p_mean_variance(self, x_t, t):
        """Return mean and variance of p(x_{t-1} | x_t) for one step."""
        # Clip the t=0 variance to the posterior value, as in the
        # reference DDPM implementation.
        var = torch.cat([self.posterior_coeff[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)
        eps = self.model(x_t, t)
        xt_prev_mean = self.pred_xt_prev_mean_from_eps(x_t, t, eps)
        return xt_prev_mean, var

    def forward(self, x_T):
        """Run the full reverse chain from noise x_T down to x_0 in [-1, 1]."""
        x_t = x_T.to("cuda")
        for timestep in reversed(range(self.T)):
            print(f"Sampling timestep: {timestep}")
            # Broadcast the scalar timestep to a per-sample index vector.
            t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * timestep
            mean, var = self.p_mean_variance(x_t, t)
            mean, var = mean.to("cuda"), var.to("cuda")
            # Inject fresh noise at every step except the final one (t = 0).
            noise = torch.randn_like(x_t).to("cuda") if timestep > 0 else 0
            x_t = mean + torch.sqrt(var) * noise
        return torch.clip(x_t, -1, 1)