ControlNet / scheduler /linear_scheduler.py
YashNagraj75's picture
Add ControlNet and Scheduler
d62b4c3
import torch
class LinearNoiseScheduler:
r"""
Class for the linear noise scheduler that is used in DDPM.
"""
def __init__(self, num_timesteps, beta_start, beta_end, ldm_scheduler=False):
self.num_timesteps = num_timesteps
self.beta_start = beta_start
self.beta_end = beta_end
if ldm_scheduler:
# Mimicking how compvis repo creates schedule
self.betas = (
torch.linspace(beta_start**0.5, beta_end**0.5, num_timesteps) ** 2
)
else:
self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
self.alphas = 1.0 - self.betas
self.alpha_cum_prod = torch.cumprod(self.alphas, dim=0)
self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)
def add_noise(self, original, noise, t):
r"""
Forward method for diffusion
:param original: Image on which noise is to be applied
:param noise: Random Noise Tensor (from normal dist)
:param t: timestep of the forward process of shape -> (B,)
:return:
"""
original_shape = original.shape
batch_size = original_shape[0]
sqrt_alpha_cum_prod = self.sqrt_alpha_cum_prod.to(original.device)[t].reshape(
batch_size
)
sqrt_one_minus_alpha_cum_prod = self.sqrt_one_minus_alpha_cum_prod.to(
original.device
)[t].reshape(batch_size)
# Reshape till (B,) becomes (B,1,1,1) if image is (B,C,H,W)
for _ in range(len(original_shape) - 1):
sqrt_alpha_cum_prod = sqrt_alpha_cum_prod.unsqueeze(-1)
for _ in range(len(original_shape) - 1):
sqrt_one_minus_alpha_cum_prod = sqrt_one_minus_alpha_cum_prod.unsqueeze(-1)
# Apply and Return Forward process equation
return (
sqrt_alpha_cum_prod.to(original.device) * original
+ sqrt_one_minus_alpha_cum_prod.to(original.device) * noise
)
def sample_prev_timestep(self, xt, noise_pred, t):
r"""
Use the noise prediction by model to get
xt-1 using xt and the noise predicted
:param xt: current timestep sample
:param noise_pred: model noise prediction
:param t: current timestep we are at
:return:
"""
x0 = (
xt - (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t] * noise_pred)
) / torch.sqrt(self.alpha_cum_prod.to(xt.device)[t])
x0 = torch.clamp(x0, -1.0, 1.0)
mean = (
xt
- ((self.betas.to(xt.device)[t]) * noise_pred)
/ (self.sqrt_one_minus_alpha_cum_prod.to(xt.device)[t])
)
mean = mean / torch.sqrt(self.alphas.to(xt.device)[t])
if t == 0:
return mean, x0
else:
variance = (1 - self.alpha_cum_prod.to(xt.device)[t - 1]) / (
1.0 - self.alpha_cum_prod.to(xt.device)[t]
)
variance = variance * self.betas.to(xt.device)[t]
sigma = variance**0.5
z = torch.randn(xt.shape).to(xt.device)
# OR
# variance = self.betas[t]
# sigma = variance ** 0.5
# z = torch.randn(xt.shape).to(xt.device)
return mean + sigma * z, x0