BiliSakura commited on
Commit
3b98545
·
verified ·
1 Parent(s): 073ca8f

Update all files for EO-VAE

Browse files
Files changed (1) hide show
  1. _eo_vae/distributions.py +22 -0
_eo_vae/distributions.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MIT License - Based on CompVis latent-diffusion / diffusers
2
+ # Diagonal Gaussian for VAE latent distribution
3
+
4
+ import torch
5
+
6
+
7
+ class DiagonalGaussianDistribution:
8
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False) -> None:
9
+ self.parameters = parameters
10
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
11
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
12
+ self.deterministic = deterministic
13
+ self.std = torch.exp(0.5 * self.logvar)
14
+ self.var = torch.exp(self.logvar)
15
+ if self.deterministic:
16
+ self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device)
17
+
18
+ def sample(self) -> torch.Tensor:
19
+ return self.mean + self.std * torch.randn_like(self.mean, device=self.parameters.device)
20
+
21
+ def mode(self) -> torch.Tensor:
22
+ return self.mean