Update noise_scheduler.py
Browse files- noise_scheduler.py +44 -45
noise_scheduler.py
CHANGED
|
@@ -1,46 +1,45 @@
|
|
| 1 |
-
import numpy as np
|
| 2 |
-
import torch
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
-
|
| 9 |
-
-
|
| 10 |
-
-
|
| 11 |
-
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
self.
|
| 16 |
-
self.
|
| 17 |
-
self.
|
| 18 |
-
self.
|
| 19 |
-
self.
|
| 20 |
-
self.
|
| 21 |
-
self.
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
x0 = (
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
xt_1
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
x = self.ddim_step(x, t=t, eps=epsilons)
|
| 46 |
return x
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
class NoiseSchedule:
    """
    Linear-beta diffusion noise schedule.

    Handles:
    - DDIM inference (with a ddim_mod to skip steps)
    - DDPM inference
    - Forward Noising
    - Linear beta schedule
    - Classifier Free Guidance (w is a hyperparameter for cfg schedule)
    """

    def __init__(self, T, std=1, shape=(4, 64, 64), ddim_mod=10, trainer_mode=False):
        """
        Args:
            T: total number of diffusion timesteps.
            std: scale applied to every sampled Gaussian noise tensor.
            shape: per-sample latent shape used by `generate`
                (presumably (C, H, W) latents — TODO confirm against the model).
            ddim_mod: DDIM stride; `generate` takes one step every
                `ddim_mod` timesteps.
            trainer_mode: keep schedule tensors on CPU when True,
                otherwise place them on CUDA.
        """
        device = 'cpu' if trainer_mode else 'cuda'
        self.T = T
        self.std = std
        self.ddim_mod = ddim_mod
        self.shape = shape
        # Linear beta schedule: 1e-4 -> 0.02 over T steps (Ho et al. 2020).
        self.beta = torch.tensor(np.linspace(1e-4, 0.02, T), dtype=torch.float32, device=device)
        self.alpha = 1 - self.beta
        self.alpha_bar = self.alpha.cumprod(dim=0)
        # Per-timestep classifier-free-guidance weight (constant 7.5).
        self.w = torch.full((T,), 7.5, device=device)

    def noise(self, x, t):
        """Forward-noise `x` to timestep `t`.

        Returns:
            (x_t, eps): the noised sample and the Gaussian noise used.
        """
        eps = torch.randn_like(x) * self.std
        return (self.alpha_bar[t]**0.5) * x + ((1 - self.alpha_bar[t])**0.5) * eps, eps

    def ddim_step(self, xt, t, eps):
        """One deterministic DDIM update from timestep `t` to max(0, t - ddim_mod)."""
        ab_t = self.alpha_bar[t]
        # Recover the model's x0 prediction, clamped to the valid data range.
        x0 = (xt - (1 - ab_t).sqrt() * eps) / ab_t.sqrt()
        x0 = x0.clamp(-1, 1)
        # note that eps = (xt - sqrt(abar[t]) * x0) / sqrt(1 - abar[t])
        ab_prev = self.alpha_bar[max(0, t - self.ddim_mod)]
        return ab_prev.sqrt() * x0 + (1 - ab_prev).sqrt() * eps

    def ddpm_step(self, x, eps, t, var=None):
        """One stochastic DDPM ancestral step from timestep `t`.

        Args:
            x: current sample x_t.
            eps: model's noise prediction at timestep t.
            t: integer timestep index.
            var: per-step variance; defaults to beta_t.

        Fixes vs. the previous version: the eps coefficient is
        beta_t / sqrt(1 - alpha_bar_t) (it was sqrt(1 - alpha_bar_t)),
        and the noise is scaled by the std sqrt(var) (it was var).
        NOTE(review): canonical DDPM adds no noise at t == 0; callers
        that sample all the way down may want to pass var=0 there.
        """
        var = self.beta[t] if var is None else var
        mean = (self.alpha[t] ** -0.5) * (x - (self.beta[t] / (1 - self.alpha_bar[t]) ** 0.5) * eps)
        return mean + (var ** 0.5) * torch.randn_like(x)

    def generate(self, model, num_images=16, device="cuda"):
        """Sample `num_images` latents via DDIM, striding `ddim_mod` steps.

        Args:
            model: callable `model(x, t=t_tensor)` returning the eps prediction.
            num_images: batch size to sample.
            device: device for the sampling tensors; must match where the
                schedule tensors live (see `trainer_mode` in __init__).
        """
        with torch.no_grad():
            x = torch.randn((num_images, *self.shape), device=device) * self.std
            for t in range(self.T - 1, -1, -self.ddim_mod):
                t_tensor = torch.full((num_images,), t, device=device)
                epsilons = model(x, t=t_tensor)
                x = self.ddim_step(x, t=t, eps=epsilons)
        return x
|