caixiaoshun commited on
Commit
d0e6306
·
verified ·
1 Parent(s): 020b1da

Create scheduler.py

Browse files
Files changed (1) hide show
  1. scheduler.py +50 -0
scheduler.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class DDPMSchedule(nn.Module):
6
+ def __init__(self, beta_start=1e-4, beta_end=0.02, steps=1000):
7
+ super().__init__()
8
+ beta = torch.linspace(beta_start, beta_end, steps)
9
+ alpha = 1 - beta
10
+ alpha_cumprod = torch.cumprod(alpha, dim=-1)
11
+ self.steps = steps
12
+ self.register_buffer("beta", beta, persistent=False)
13
+ self.register_buffer("alpha", alpha, persistent=False)
14
+ self.register_buffer("alpha_cumprod", alpha_cumprod, persistent=False)
15
+
16
+ def add_noise(
17
+ self, image: torch.Tensor, time: torch.LongTensor, noise: torch.Tensor
18
+ ) -> torch.Tensor:
19
+ alpha_cumprod_t = self.alpha_cumprod[time]
20
+
21
+ while len(alpha_cumprod_t.shape) < len(image.shape):
22
+ alpha_cumprod_t = alpha_cumprod_t.unsqueeze(-1)
23
+
24
+ noise_image = torch.sqrt(alpha_cumprod_t) * image + torch.sqrt(1-alpha_cumprod_t) * noise
25
+ return noise_image
26
+
27
+ def step(self, x_t:torch.Tensor, pred_noise:torch.Tensor, time:torch.LongTensor):
28
+ alpha_t = self.alpha[time]
29
+ alpha_cumprod_t = self.alpha_cumprod[time]
30
+ alpha_cumprod_t_minus_1 = self.alpha_cumprod[time-1]
31
+ beta_t = self.beta[time]
32
+ zero_mask = (time > 0).float()
33
+
34
+ while len(alpha_t.shape) < len(x_t.shape):
35
+ alpha_t = alpha_t.unsqueeze(-1)
36
+ alpha_cumprod_t = alpha_cumprod_t.unsqueeze(-1)
37
+ beta_t = beta_t.unsqueeze(-1)
38
+ zero_mask = zero_mask.unsqueeze(-1)
39
+ alpha_cumprod_t_minus_1 = alpha_cumprod_t_minus_1.unsqueeze(-1)
40
+
41
+ noise_weight = (1-alpha_t)/torch.sqrt(1-alpha_cumprod_t)
42
+ noise = torch.randn_like(x_t)
43
+
44
+
45
+ eta = (1-alpha_cumprod_t_minus_1) / (1-alpha_cumprod_t)
46
+ eta = torch.sqrt(eta * beta_t)
47
+
48
+
49
+ x_t_minus_1 = (x_t - noise_weight * pred_noise) / torch.sqrt(alpha_t) + eta * noise * zero_mask
50
+ return x_t_minus_1