# cat_diffusion / noise_scheduler.py
# Last commit b384f6d (verified) by detectivejoewest: "Update noise_scheduler.py"
import numpy as np
import torch
class NoiseSchedule:
    """Linear-beta diffusion schedule with forward noising and samplers.

    Handles:
        - DDIM inference (``ddim_mod`` is the stride used to skip steps)
        - DDPM inference
        - Forward noising
        - Linear beta schedule
        - Classifier-free guidance weights: ``w`` stores one weight per
          timestep (7.5 everywhere); no method of this class reads it.

    Every method taking a timestep ``t`` accepts a Python int, a 0-dim
    tensor, or a tensor of per-sample timesteps whose leading dimensions
    match the leading dimensions of the sample tensor.
    """

    def __init__(self, T, std=1, shape=(4, 64, 64), ddim_mod=10, trainer_mode=False):
        """Precompute ``beta``, ``alpha``, ``alpha_bar`` and ``w``.

        Args:
            T: Number of diffusion timesteps.
            std: Standard deviation of the Gaussian noise used by the
                forward process and by the samplers.
            shape: Per-sample shape of the tensors created by ``generate``.
            ddim_mod: Stride between the timesteps visited by DDIM sampling.
            trainer_mode: When true the schedule tensors are kept on the
                CPU. Otherwise they are placed on CUDA, falling back to the
                CPU when CUDA is not available.
        """
        self.T = T
        self.std = std
        self.ddim_mod = ddim_mod
        self.shape = shape

        # Fall back to the CPU instead of crashing on CUDA-less hosts.
        use_cuda = not trainer_mode and torch.cuda.is_available()
        device = "cuda" if use_cuda else "cpu"

        # The endpoints are computed by numpy (float64) and then cast, so
        # the float32 values are the same as they have always been.
        self.beta = torch.tensor(
            np.linspace(1e-4, 0.02, T), dtype=torch.float32, device=device
        )
        self.alpha = 1 - self.beta
        self.alpha_bar = self.alpha.cumprod(dim=0)
        self.w = torch.full((T,), 7.5, device=device)

    @staticmethod
    def _expand(coef, ref):
        """Append singleton dims so per-sample ``coef`` broadcasts over ``ref``."""
        # 0-dim coefficients are left untouched: they already broadcast and,
        # unlike dimensioned tensors, do not change the dtype of ``ref``.
        if coef.ndim:
            coef = coef.reshape(tuple(coef.shape) + (1,) * (ref.ndim - coef.ndim))
        return coef

    def _at(self, values, t, ref):
        """Return ``values[t]`` on ``ref``'s device, broadcastable to ``ref``."""
        if torch.is_tensor(t):
            # Index tensors must live on the same device as the table.
            t = t.to(values.device)
        return self._expand(values[t].to(ref.device), ref)

    def noise(self, x, t):
        """Apply the forward process to ``x`` at timestep(s) ``t``.

        Args:
            x: Clean samples.
            t: Timestep index (int, 0-dim tensor, or per-sample tensor).

        Returns:
            A tuple ``(x_t, eps)`` where ``eps`` is the noise that was
            mixed in, already scaled by ``std``.
        """
        eps = torch.randn_like(x) * self.std
        abar = self._at(self.alpha_bar, t, x)
        return abar.sqrt() * x + (1 - abar).sqrt() * eps, eps

    def ddim_step(self, xt, t, eps):
        """Take one deterministic DDIM step from ``t`` to ``t - ddim_mod``.

        The target timestep is clamped at 0, so the last step lands on the
        noise level of ``alpha_bar[0]``.

        Args:
            xt: Noisy samples at timestep ``t``.
            t: Current timestep (int, 0-dim tensor, or per-sample tensor).
            eps: Noise predicted for ``xt``.

        Returns:
            The samples at the target timestep.
        """
        if torch.is_tensor(t):
            t_prev = (t - self.ddim_mod).clamp(min=0)
        else:
            t_prev = max(0, t - self.ddim_mod)

        abar_t = self._at(self.alpha_bar, t, xt)
        abar_prev = self._at(self.alpha_bar, t_prev, xt)

        x0 = (xt - (1 - abar_t).sqrt() * eps) / abar_t.sqrt()
        x0 = x0.clamp(-1, 1)
        # note that eps = (xt - sqrt(abar[t]) * x0) / sqrt(1 - abar[t]);
        # the predicted eps is reused as-is and is not recomputed from the
        # clamped x0.
        return abar_prev.sqrt() * x0 + (1 - abar_prev).sqrt() * eps

    def ddpm_step(self, x, eps, t, var=None):
        """Take one ancestral DDPM step from ``t`` to ``t - 1``.

        Computes ``(x - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)``
        and adds Gaussian noise with standard deviation ``sqrt(var) * std``.
        No noise is added where ``t == 0`` (the final step).

        Args:
            x: Noisy samples at timestep ``t``.
            eps: Noise predicted for ``x``.
            t: Current timestep (int, 0-dim tensor, or per-sample tensor).
            var: Variance of the injected noise; defaults to ``beta[t]``.

        Returns:
            The samples at timestep ``t - 1``.
        """
        beta_t = self._at(self.beta, t, x)
        alpha_t = self._at(self.alpha, t, x)
        abar_t = self._at(self.alpha_bar, t, x)

        mean = (x - beta_t / (1 - abar_t).sqrt() * eps) / alpha_t.sqrt()

        # ``var`` is a variance, so the noise is scaled by its square root.
        sigma = (beta_t if var is None else var) ** 0.5
        # Boolean mask that switches the noise off for samples at t == 0.
        add_noise = self._expand(torch.as_tensor(t, device=x.device) > 0, x)
        fresh = torch.randn_like(x) * self.std
        return mean + add_noise * sigma * fresh

    def generate(self, model, num_images=16, device="cuda"):
        """Sample ``num_images`` tensors with strided DDIM.

        Starts from Gaussian noise scaled by ``std`` and visits the
        timesteps ``T - 1, T - 1 - ddim_mod, ...`` down to 0 at most,
        calling ``model(x, t=<per-sample timestep tensor>)`` to obtain the
        noise prediction at each visited step. Runs without gradients.

        Args:
            model: Callable returning the predicted noise for ``x``.
            num_images: Number of samples to generate.
            device: Device on which the samples are created.

        Returns:
            A tensor of shape ``(num_images, *shape)``.
        """
        with torch.no_grad():
            x = torch.randn((num_images, *self.shape), device=device) * self.std
            for t in range(self.T - 1, -1, -self.ddim_mod):
                t_tensor = torch.full((num_images,), t, device=device)
                predicted = model(x, t=t_tensor)
                x = self.ddim_step(x, t=t, eps=predicted)
        return x