Yash Nagraj commited on
Commit
95dc5d6
·
1 Parent(s): e6092ac

Finished the Sampler

Browse files
Files changed (1) hide show
  1. diffusion.py +46 -1
diffusion.py CHANGED
@@ -59,10 +59,55 @@ class GaussianDiffusionSampler(nn.Module):
59
  )
60
  self.alphas = 1 - self.betas
61
  self.beta_alphas = torch.cumprod(self.alphas,dim=0)
62
-
63
  """
64
  This line of code pads the tensor self.beta_alphas by adding a single element with the value 1 to the beginning of the tensor.
65
  The resulting tensor is stored in self.beta_alphas_prev.
66
  """
67
  self.beta_alphas_prev = F.pad(self.beta_alphas,[1,0],value=1)[:T]
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  )
60
  self.alphas = 1 - self.betas
61
  self.beta_alphas = torch.cumprod(self.alphas,dim=0)
62
+
63
  """
64
  This line of code pads the tensor self.beta_alphas by adding a single element with the value 1 to the beginning of the tensor.
65
  The resulting tensor is stored in self.beta_alphas_prev.
66
  """
67
  self.beta_alphas_prev = F.pad(self.beta_alphas,[1,0],value=1)[:T]
68
 
69
+ self.register_buffer(
70
+ "coeff1",
71
+ (1 / torch.sqrt(self.alphas))
72
+ )
73
+
74
+ self.register_buffer(
75
+ "coeff2",
76
+ self.coeff1 * ((1- self.alphas) / (torch.sqrt(1-self.beta_alphas)))
77
+ )
78
+
79
+ self.register_buffer(
80
+ "posterior_coeff",
81
+ (1 - self.beta_alphas_prev) / (1-self.beta_alphas) * self.betas
82
+ )
83
+
84
+ def pred_xt_prev_mean_from_eps(self,x_t,t,eps):
85
+ return (
86
+ extract(self.coeff1,t,x_t.shape) * x_t -
87
+ extract(self.coeff2,t,x_t.shape) * eps
88
+ )
89
+
90
+ def p_mean_variance(self,x_t,t):
91
+ var = torch.cat([self.posterior_coeff[1:2],self.betas[1:]])
92
+ var = extract(var,t,x_t.shape)
93
+
94
+ eps = self.model(x_t,t)
95
+ xt_prev_mean = self.pred_xt_prev_mean_from_eps(x_t,t,eps)
96
+ return xt_prev_mean,var
97
+
98
+ def forward(self,x_T):
99
+ x_t=x_T
100
+ for timestep in reversed(range(self.T)):
101
+ print(f"Sampling timestep: {timestep}")
102
+
103
+ t = x_t.new_ones([x_t.shape[0],], dtype=torch.long) * timestep
104
+ mean, var = self.p_mean_variance(x_t,t)
105
+ if timestep > 0:
106
+ noise = torch.randn_like(x_t)
107
+ else:
108
+ noise = 0
109
+ x_t = mean + torch.sqrt(var) * noise
110
+
111
+ x_0 = x_t
112
+ return torch.clip(x_0,-1,1)
113
+