Update sampler.py
Browse files- sampler.py +8 -18
sampler.py
CHANGED
|
@@ -1,14 +1,13 @@
|
|
| 1 |
import torch
|
| 2 |
-
import torch.nn.functional as F
|
| 3 |
|
| 4 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 5 |
|
|
|
|
| 6 |
def extract(a, t, x_shape):
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
|
| 12 |
|
| 13 |
beta_start = 1e-4
|
| 14 |
beta_end = 0.02
|
|
@@ -21,28 +20,20 @@ betas = torch.linspace(
|
|
| 21 |
dtype=torch.float32,
|
| 22 |
device=device,
|
| 23 |
)
|
|
|
|
| 24 |
alphas = 1.0 - betas
|
| 25 |
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 26 |
|
| 27 |
-
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
| 28 |
-
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
| 29 |
-
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
| 30 |
-
|
| 31 |
-
sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
|
| 32 |
-
|
| 33 |
-
posterior_variance = (betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod))
|
| 34 |
|
| 35 |
@torch.no_grad()
|
| 36 |
def ddim_sample(model, x, t, t_prev, labels=None):
|
| 37 |
-
|
| 38 |
alpha_bar_t = extract(alphas_cumprod, t, x.shape)
|
| 39 |
alpha_bar_prev = extract(alphas_cumprod, t_prev, x.shape)
|
| 40 |
|
| 41 |
pred_noise = model(x, t, labels)
|
| 42 |
|
| 43 |
x0 = (
|
| 44 |
-
x
|
| 45 |
-
- torch.sqrt(1 - alpha_bar_t) * pred_noise
|
| 46 |
) / torch.sqrt(alpha_bar_t)
|
| 47 |
|
| 48 |
x_prev = (
|
|
@@ -50,5 +41,4 @@ def ddim_sample(model, x, t, t_prev, labels=None):
|
|
| 50 |
+ torch.sqrt(1 - alpha_bar_prev) * pred_noise
|
| 51 |
)
|
| 52 |
|
| 53 |
-
return x_prev
|
| 54 |
-
|
|
|
|
| 1 |
import torch
|
|
|
|
| 2 |
|
| 3 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 4 |
|
| 5 |
+
|
| 6 |
def extract(a, t, x_shape):
|
| 7 |
+
batch_size = t.shape[0]
|
| 8 |
+
out = a.gather(0, t.to(a.device))
|
| 9 |
+
return out.reshape(batch_size, *((1,) * (len(x_shape) - 1)))
|
| 10 |
+
|
|
|
|
| 11 |
|
| 12 |
beta_start = 1e-4
|
| 13 |
beta_end = 0.02
|
|
|
|
| 20 |
dtype=torch.float32,
|
| 21 |
device=device,
|
| 22 |
)
|
| 23 |
+
|
| 24 |
alphas = 1.0 - betas
|
| 25 |
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
@torch.no_grad()
|
| 29 |
def ddim_sample(model, x, t, t_prev, labels=None):
|
|
|
|
| 30 |
alpha_bar_t = extract(alphas_cumprod, t, x.shape)
|
| 31 |
alpha_bar_prev = extract(alphas_cumprod, t_prev, x.shape)
|
| 32 |
|
| 33 |
pred_noise = model(x, t, labels)
|
| 34 |
|
| 35 |
x0 = (
|
| 36 |
+
x - torch.sqrt(1 - alpha_bar_t) * pred_noise
|
|
|
|
| 37 |
) / torch.sqrt(alpha_bar_t)
|
| 38 |
|
| 39 |
x_prev = (
|
|
|
|
| 41 |
+ torch.sqrt(1 - alpha_bar_prev) * pred_noise
|
| 42 |
)
|
| 43 |
|
| 44 |
+
return x_prev
|
|
|