detectivejoewest commited on
Commit
b384f6d
·
verified ·
1 Parent(s): 6c662fd

Update noise_scheduler.py

Browse files
Files changed (1) hide show
  1. noise_scheduler.py +44 -45
noise_scheduler.py CHANGED
@@ -1,46 +1,45 @@
1
- import numpy as np
2
- import torch
3
- from IPython.display import clear_output
4
-
5
- class NoiseSchedule:
6
- """
7
- Handles:
8
- - DDIM inference (with a ddim_mod to skip steps)
9
- - DDPM inference
10
- - Forward Noising
11
- - Linear beta schedule
12
- - Classifier Free Guidance (w is a hyperparameter for cfg schedule)
13
- """
14
- def __init__(self, T, std=1, shape=(4, 64, 64), ddim_mod=10, trainer_mode=False):
15
- self.T = T
16
- self.std = std
17
- self.ddim_mod = ddim_mod
18
- self.beta = torch.tensor(np.linspace(1e-4, 0.02, T), dtype=torch.float32, device='cpu' if trainer_mode else 'cuda')
19
- self.alpha = 1 - self.beta
20
- self.alpha_bar = self.alpha.cumprod(dim=0)
21
- self.w = torch.full((T,), 7.5, device='cpu' if trainer_mode else 'cuda')
22
- self.shape = shape
23
-
24
- def noise(self, x, t):
25
- eps = torch.randn_like(x) * self.std
26
- return (self.alpha_bar[t]**0.5) * x + ((1-self.alpha_bar[t])**0.5) * eps, eps
27
-
28
- def ddim_step(self, xt, t, eps):
29
- x0 = (xt - (1 - self.alpha_bar[t]).sqrt() * eps) / self.alpha_bar[t].sqrt()
30
- x0 = x0.clamp(-1, 1)
31
- # note that eps = (xt - sqrt(abar[t]) * x0) / sqrt(1 - abar[t])
32
- xt_1 = self.alpha_bar[max(0,t - self.ddim_mod)].sqrt() * x0 + (1 - self.alpha_bar[max(0,t - self.ddim_mod)]).sqrt() * eps
33
- return xt_1
34
-
35
- def ddpm_step(self, x, eps, t, var=None):
36
- var = self.beta[t] if var is None else var
37
- return (self.alpha[t]**-0.5) * (x - ((1 - self.alpha_bar[t])**0.5) * eps) + var * torch.randn_like(x)
38
-
39
- def generate(self, model, num_images=16, device="cuda"):
40
- with torch.no_grad():
41
- x = torch.randn((num_images, *self.shape), device=device) * self.std
42
- for t in range(self.T-1, -1, -self.ddim_mod):
43
- t_tensor = torch.full((num_images,),t, device=device)
44
- epsilons = model(x, t=t_tensor)
45
- x = self.ddim_step(x, t=t, eps=epsilons)
46
  return x
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ class NoiseSchedule:
5
+ """
6
+ Handles:
7
+ - DDIM inference (with a ddim_mod to skip steps)
8
+ - DDPM inference
9
+ - Forward Noising
10
+ - Linear beta schedule
11
+ - Classifier Free Guidance (w is a hyperparameter for cfg schedule)
12
+ """
13
+ def __init__(self, T, std=1, shape=(4, 64, 64), ddim_mod=10, trainer_mode=False):
14
+ self.T = T
15
+ self.std = std
16
+ self.ddim_mod = ddim_mod
17
+ self.beta = torch.tensor(np.linspace(1e-4, 0.02, T), dtype=torch.float32, device='cpu' if trainer_mode else 'cuda')
18
+ self.alpha = 1 - self.beta
19
+ self.alpha_bar = self.alpha.cumprod(dim=0)
20
+ self.w = torch.full((T,), 7.5, device='cpu' if trainer_mode else 'cuda')
21
+ self.shape = shape
22
+
23
+ def noise(self, x, t):
24
+ eps = torch.randn_like(x) * self.std
25
+ return (self.alpha_bar[t]**0.5) * x + ((1-self.alpha_bar[t])**0.5) * eps, eps
26
+
27
+ def ddim_step(self, xt, t, eps):
28
+ x0 = (xt - (1 - self.alpha_bar[t]).sqrt() * eps) / self.alpha_bar[t].sqrt()
29
+ x0 = x0.clamp(-1, 1)
30
+ # note that eps = (xt - sqrt(abar[t]) * x0) / sqrt(1 - abar[t])
31
+ xt_1 = self.alpha_bar[max(0,t - self.ddim_mod)].sqrt() * x0 + (1 - self.alpha_bar[max(0,t - self.ddim_mod)]).sqrt() * eps
32
+ return xt_1
33
+
34
+ def ddpm_step(self, x, eps, t, var=None):
35
+ var = self.beta[t] if var is None else var
36
+ return (self.alpha[t]**-0.5) * (x - ((1 - self.alpha_bar[t])**0.5) * eps) + var * torch.randn_like(x)
37
+
38
+ def generate(self, model, num_images=16, device="cuda"):
39
+ with torch.no_grad():
40
+ x = torch.randn((num_images, *self.shape), device=device) * self.std
41
+ for t in range(self.T-1, -1, -self.ddim_mod):
42
+ t_tensor = torch.full((num_images,),t, device=device)
43
+ epsilons = model(x, t=t_tensor)
44
+ x = self.ddim_step(x, t=t, eps=epsilons)
 
45
  return x