Spaces:
Runtime error
Runtime error
| import torch | |
| from torch import nn | |
class VariationalAutoEncoder(nn.Module):
    """Fully connected variational autoencoder (VAE).

    Data flow: input -> hidden layer -> (mu, sigma) -> reparameterization
    trick -> decoder -> reconstruction squashed into [0, 1] by a sigmoid.

    Args:
        inpud_dim: Number of features per input sample (784 in the demo,
            i.e. a flattened 28x28 image). The misspelled name is kept
            because callers pass it by keyword.
        h_dim: Width of the single hidden layer used by encoder and decoder.
        z_dim: Dimensionality of the latent code.
    """

    def __init__(self, inpud_dim, h_dim=200, z_dim=20):
        super().__init__()
        # Layers are created in this exact order and under these exact
        # attribute names: the order determines parameter initialisation
        # under a fixed seed, and the names are the state_dict() keys.

        # Encoder: input -> hidden, then two heads for mu and sigma.
        self.img_2hid = nn.Linear(inpud_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        # NOTE(review): no activation follows this head, so sigma can be
        # negative or zero. A loss that takes log(sigma**2) would then
        # produce -inf/NaN. Consider softplus/exp or predicting log-variance,
        # but note that doing so changes outputs of already-trained weights.
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # Decoder: latent -> hidden -> reconstructed input.
        self.z_2hi = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, inpud_dim)

        # Shared non-linearity for the hidden layers of both halves.
        self.relu = nn.ReLU()

    def encode(self, x):
        """Compute the parameters of q_phi(z | x).

        Returns:
            A ``(mu, sigma)`` tuple: the outputs of the two linear heads
            applied to the ReLU-activated hidden representation of ``x``.
        """
        hidden = self.relu(self.img_2hid(x))
        return self.hid_2mu(hidden), self.hid_2sigma(hidden)

    def decode(self, z):
        """Map latent codes back to input space, p_theta(x | z).

        The final sigmoid squashes every output element into [0, 1], the
        range expected for normalised image intensities.
        """
        logits = self.hid_2img(self.relu(self.z_2hi(z)))
        return torch.sigmoid(logits)

    def forward(self, x):
        """Encode, sample a latent via the reparameterization trick, decode.

        Returns:
            ``(x_reconstructed, mu, sigma)``. mu and sigma are returned
            alongside the reconstruction so the caller can build both loss
            terms:
            1. a term pushing (mu, sigma) toward a standard normal;
            2. a term comparing x_reconstructed with x.
        """
        mu, sigma = self.encode(x)
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, 1).
        # The randomness lives entirely in eps, so gradients can flow
        # through mu and sigma.
        noise = torch.randn_like(sigma)
        z = mu + sigma * noise
        return self.decode(z), mu, sigma
if __name__ == "__main__":
    # Shape smoke test: push a random batch of four flattened 28x28 inputs
    # through an untrained model and print what comes back.
    # NOTE(review): torch.randn draws standard-normal values, not pixel
    # intensities in [0, 1]. That is fine for checking shapes, but use
    # normalised data if a reconstruction loss is computed against x.
    batch = torch.randn(4, 28 * 28)
    model = VariationalAutoEncoder(inpud_dim=784)
    recon, mean, std = model(batch)
    for tensor in (recon, mean, std):
        print(tensor.shape)
    print(torch.mean(mean))