nitesh501 commited on
Commit
28c8f91
·
verified ·
1 Parent(s): 45f73cb

Update sampler.py

Browse files
Files changed (1) hide show
  1. sampler.py +8 -18
sampler.py CHANGED
@@ -1,14 +1,13 @@
1
  import torch
2
- import torch.nn.functional as F
3
 
4
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
5
 
 
6
  def extract(a, t, x_shape):
7
- batch_size = t.shape[0]
8
- t = t.to(a.device)
9
- out = a.gather(-1, t)
10
- out = out.to(t.device)
11
- return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
12
 
13
  beta_start = 1e-4
14
  beta_end = 0.02
@@ -21,28 +20,20 @@ betas = torch.linspace(
21
  dtype=torch.float32,
22
  device=device,
23
  )
 
24
  alphas = 1.0 - betas
25
  alphas_cumprod = torch.cumprod(alphas, dim=0)
26
 
27
- alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
28
- sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
29
- sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
30
-
31
- sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
32
-
33
- posterior_variance = (betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod))
34
 
35
  @torch.no_grad()
36
  def ddim_sample(model, x, t, t_prev, labels=None):
37
-
38
  alpha_bar_t = extract(alphas_cumprod, t, x.shape)
39
  alpha_bar_prev = extract(alphas_cumprod, t_prev, x.shape)
40
 
41
  pred_noise = model(x, t, labels)
42
 
43
  x0 = (
44
- x
45
- - torch.sqrt(1 - alpha_bar_t) * pred_noise
46
  ) / torch.sqrt(alpha_bar_t)
47
 
48
  x_prev = (
@@ -50,5 +41,4 @@ def ddim_sample(model, x, t, t_prev, labels=None):
50
  + torch.sqrt(1 - alpha_bar_prev) * pred_noise
51
  )
52
 
53
- return x_prev
54
-
 
1
  import torch
 
2
 
3
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
4
 
5
+
6
  def extract(a, t, x_shape):
7
+ batch_size = t.shape[0]
8
+ out = a.gather(0, t.to(a.device))
9
+ return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
10
+
 
11
 
12
  beta_start = 1e-4
13
  beta_end = 0.02
 
20
  dtype=torch.float32,
21
  device=device,
22
  )
23
+
24
  alphas = 1.0 - betas
25
  alphas_cumprod = torch.cumprod(alphas, dim=0)
26
 
 
 
 
 
 
 
 
27
 
28
  @torch.no_grad()
29
  def ddim_sample(model, x, t, t_prev, labels=None):
 
30
  alpha_bar_t = extract(alphas_cumprod, t, x.shape)
31
  alpha_bar_prev = extract(alphas_cumprod, t_prev, x.shape)
32
 
33
  pred_noise = model(x, t, labels)
34
 
35
  x0 = (
36
+ x - torch.sqrt(1 - alpha_bar_t) * pred_noise
 
37
  ) / torch.sqrt(alpha_bar_t)
38
 
39
  x_prev = (
 
41
  + torch.sqrt(1 - alpha_bar_prev) * pred_noise
42
  )
43
 
44
+ return x_prev