nitesh501 commited on
Commit
d3ae734
·
verified ·
1 Parent(s): 447c1eb

Delete sampler.py

Browse files
Files changed (1) hide show
  1. sampler.py +0 -44
sampler.py DELETED
@@ -1,44 +0,0 @@
1
- import torch
2
-
3
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
4
-
5
-
6
- def extract(a, t, x_shape):
7
- batch_size = t.shape[0]
8
- out = a.gather(0, t.to(a.device))
9
- return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
10
-
11
-
12
- beta_start = 1e-4
13
- beta_end = 0.02
14
- num_timesteps = 1000
15
-
16
- betas = torch.linspace(
17
- beta_start,
18
- beta_end,
19
- num_timesteps,
20
- dtype=torch.float32,
21
- device=device,
22
- )
23
-
24
- alphas = 1.0 - betas
25
- alphas_cumprod = torch.cumprod(alphas, dim=0)
26
-
27
-
28
- @torch.no_grad()
29
- def ddim_sample(model, x, t, t_prev, labels=None):
30
- alpha_bar_t = extract(alphas_cumprod, t, x.shape)
31
- alpha_bar_prev = extract(alphas_cumprod, t_prev, x.shape)
32
-
33
- pred_noise = model(x, t, labels)
34
-
35
- x0 = (
36
- x - torch.sqrt(1 - alpha_bar_t) * pred_noise
37
- ) / torch.sqrt(alpha_bar_t)
38
-
39
- x_prev = (
40
- torch.sqrt(alpha_bar_prev) * x0
41
- + torch.sqrt(1 - alpha_bar_prev) * pred_noise
42
- )
43
-
44
- return x_prev