| | import numpy as np |
| | import torch |
| |
|
| | class NoiseSchedule: |
| | """ |
| | Handles: |
| | - DDIM inference (with a ddim_mod to skip steps) |
| | - DDPM inference |
| | - Forward Noising |
| | - Linear beta schedule |
| | - Classifier Free Guidance (w is a hyperparameter for cfg schedule) |
| | """ |
| | def __init__(self, T, std=1, shape=(4, 64, 64), ddim_mod=10, trainer_mode=False): |
| | self.T = T |
| | self.std = std |
| | self.ddim_mod = ddim_mod |
| | self.beta = torch.tensor(np.linspace(1e-4, 0.02, T), dtype=torch.float32, device='cpu' if trainer_mode else 'cuda') |
| | self.alpha = 1 - self.beta |
| | self.alpha_bar = self.alpha.cumprod(dim=0) |
| | self.w = torch.full((T,), 7.5, device='cpu' if trainer_mode else 'cuda') |
| | self.shape = shape |
| |
|
| | def noise(self, x, t): |
| | eps = torch.randn_like(x) * self.std |
| | return (self.alpha_bar[t]**0.5) * x + ((1-self.alpha_bar[t])**0.5) * eps, eps |
| |
|
| | def ddim_step(self, xt, t, eps): |
| | x0 = (xt - (1 - self.alpha_bar[t]).sqrt() * eps) / self.alpha_bar[t].sqrt() |
| | x0 = x0.clamp(-1, 1) |
| | |
| | xt_1 = self.alpha_bar[max(0,t - self.ddim_mod)].sqrt() * x0 + (1 - self.alpha_bar[max(0,t - self.ddim_mod)]).sqrt() * eps |
| | return xt_1 |
| |
|
| | def ddpm_step(self, x, eps, t, var=None): |
| | var = self.beta[t] if var is None else var |
| | return (self.alpha[t]**-0.5) * (x - ((1 - self.alpha_bar[t])**0.5) * eps) + var * torch.randn_like(x) |
| |
|
| | def generate(self, model, num_images=16, device="cuda"): |
| | with torch.no_grad(): |
| | x = torch.randn((num_images, *self.shape), device=device) * self.std |
| | for t in range(self.T-1, -1, -self.ddim_mod): |
| | t_tensor = torch.full((num_images,),t, device=device) |
| | epsilons = model(x, t=t_tensor) |
| | x = self.ddim_step(x, t=t, eps=epsilons) |
| | return x |