Spaces:
Runtime error
Runtime error
Create scheduler.py
Browse files- scheduler.py +50 -0
scheduler.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class DDPMSchedule(nn.Module):
    """Linear-beta DDPM noise schedule.

    Precomputes ``beta``, ``alpha = 1 - beta`` and the running product
    ``alpha_cumprod`` over ``steps`` timesteps, and exposes the forward
    noising operation (:meth:`add_noise`) and a single ancestral reverse
    step (:meth:`step`).
    """

    def __init__(self, beta_start=1e-4, beta_end=0.02, steps=1000):
        """Build the schedule tables.

        Args:
            beta_start: First value of the linearly spaced ``beta`` table.
            beta_end: Last value of the linearly spaced ``beta`` table.
            steps: Number of timesteps (length of every table).
        """
        super().__init__()
        beta = torch.linspace(beta_start, beta_end, steps)
        alpha = 1 - beta
        alpha_cumprod = torch.cumprod(alpha, dim=-1)
        self.steps = steps
        # Registered as buffers so they move with the module (.to / .cuda);
        # persistent=False keeps them out of the state_dict since they are
        # fully determined by the constructor arguments.
        self.register_buffer("beta", beta, persistent=False)
        self.register_buffer("alpha", alpha, persistent=False)
        self.register_buffer("alpha_cumprod", alpha_cumprod, persistent=False)

    @staticmethod
    def _broadcast_to_rank(values: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Append trailing singleton dims to *values* up to *reference*'s rank."""
        missing = reference.dim() - values.dim()
        if missing <= 0:
            return values
        return values.reshape(tuple(values.shape) + (1,) * missing)

    def add_noise(
        self, image: torch.Tensor, time: torch.LongTensor, noise: torch.Tensor
    ) -> torch.Tensor:
        """Noise *image* to timestep *time* (forward diffusion).

        Computes ``sqrt(acp_t) * image + sqrt(1 - acp_t) * noise`` where
        ``acp_t = alpha_cumprod[time]``.

        Args:
            image: Clean sample(s).
            time: Integer timestep index/indices into the schedule; the
                looked-up coefficients are right-padded with singleton
                dims so they broadcast against ``image``.
            noise: Noise to mix in; must broadcast with ``image``.

        Returns:
            The noised sample(s).
        """
        time = torch.as_tensor(time, device=self.alpha_cumprod.device)
        alpha_cumprod_t = self._broadcast_to_rank(self.alpha_cumprod[time], image)
        signal_scale = torch.sqrt(alpha_cumprod_t)
        noise_scale = torch.sqrt(1 - alpha_cumprod_t)
        return signal_scale * image + noise_scale * noise

    def step(self, x_t: torch.Tensor, pred_noise: torch.Tensor, time: torch.LongTensor):
        """Take one reverse (denoising) step from ``x_t`` to ``x_{t-1}``.

        The mean is ``(x_t - (1 - alpha_t) / sqrt(1 - acp_t) * pred_noise)
        / sqrt(alpha_t)``; fresh Gaussian noise scaled by
        ``sqrt((1 - acp_{t-1}) / (1 - acp_t) * beta_t)`` is added for every
        element whose timestep is greater than zero.

        Args:
            x_t: Current noisy sample(s).
            pred_noise: Noise predicted for ``x_t``; must broadcast with it.
            time: Integer timestep index/indices (tensor or Python int).

        Returns:
            The sample(s) at the previous timestep.
        """
        # Accept plain ints as well as tensors so a simple
        # ``for t in reversed(range(steps))`` sampling loop works.
        time = torch.as_tensor(time, device=self.alpha_cumprod.device)
        # Clamp so t == 0 does not wrap around to the last table entry; the
        # value looked up there is irrelevant because noise is masked at t == 0.
        prev_time = (time - 1).clamp(min=0)

        alpha_t = self._broadcast_to_rank(self.alpha[time], x_t)
        alpha_cumprod_t = self._broadcast_to_rank(self.alpha_cumprod[time], x_t)
        alpha_cumprod_prev = self._broadcast_to_rank(self.alpha_cumprod[prev_time], x_t)
        beta_t = self._broadcast_to_rank(self.beta[time], x_t)
        # 1.0 where noise should be injected (t > 0), 0.0 at the final step.
        keep_noise = self._broadcast_to_rank((time > 0).float(), x_t)

        noise_weight = (1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)
        noise = torch.randn_like(x_t)

        # Standard deviation of the DDPM posterior q(x_{t-1} | x_t, x_0).
        eta = torch.sqrt((1 - alpha_cumprod_prev) / (1 - alpha_cumprod_t) * beta_t)

        mean = (x_t - noise_weight * pred_noise) / torch.sqrt(alpha_t)
        return mean + eta * noise * keep_noise
|