Spaces:
Runtime error
Runtime error
built space
Browse files- MnistVAEmodel.pt +3 -0
- app.py +35 -0
- model.py +50 -0
- original_5.png +0 -0
- original_8.png +0 -0
- requirements.txt +3 -0
MnistVAEmodel.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d6ab1a824858a37b3dbeffce09cd2de481906e689b4817e505cb2550e992d3d
|
| 3 |
+
size 4796991
|
app.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from model import VariationalAutoEncoder
|
| 3 |
+
from torchvision import transforms
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import gradio as gr
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
INPUT_DIM = 784
|
| 9 |
+
H_DIM = 512
|
| 10 |
+
Z_DIM = 256
|
| 11 |
+
|
| 12 |
+
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
|
| 13 |
+
model.load_state_dict(torch.load("MnistVAEmodel.pth"))
|
| 14 |
+
model.eval()
|
| 15 |
+
def predict(img):
|
| 16 |
+
img = img.convert('1')
|
| 17 |
+
img = transforms.ToTensor()(img)
|
| 18 |
+
img = transforms.CenterCrop(size=28)(img)
|
| 19 |
+
print(type(img), img.shape)
|
| 20 |
+
mu, sigma = model.encode(img.view(1, INPUT_DIM))
|
| 21 |
+
|
| 22 |
+
res = []
|
| 23 |
+
for example in range(10):
|
| 24 |
+
epsilon = torch.randn_like(sigma)
|
| 25 |
+
z = mu + sigma * epsilon
|
| 26 |
+
out = model.decode(z)
|
| 27 |
+
out = out.view(-1,1,28,28)
|
| 28 |
+
res.append(transforms.ToPILImage()(out[0]))
|
| 29 |
+
return res
|
| 30 |
+
|
| 31 |
+
title = "Variational-Autoencoder-on-MNIST "
|
| 32 |
+
description = "TO DO"
|
| 33 |
+
examples = ["original_5.png", "original_8.png"]
|
| 34 |
+
gr.Interface(fn=predict, inputs = gr.inputs.Image(), outputs= gr.outputs.Gallery(),
|
| 35 |
+
examples=examples, title=title, description=description).launch(inline=False)
|
model.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class VariationalAutoEncoder(nn.Module):
|
| 6 |
+
# Input image -> hidden dim -> mean, std -> parametirazation trick -> Decoder -> output image
|
| 7 |
+
def __init__(self, inpud_dim, h_dim=200, z_dim=20):
|
| 8 |
+
super().__init__()
|
| 9 |
+
|
| 10 |
+
# encoder
|
| 11 |
+
self.img_2hid = nn.Linear(inpud_dim, h_dim)
|
| 12 |
+
self.hid_2mu = nn.Linear(h_dim, z_dim)
|
| 13 |
+
self.hid_2sigma = nn.Linear(h_dim, z_dim)
|
| 14 |
+
|
| 15 |
+
# decoder
|
| 16 |
+
self.z_2hi = nn.Linear(z_dim, h_dim)
|
| 17 |
+
self.hid_2img = nn.Linear(h_dim, inpud_dim)
|
| 18 |
+
|
| 19 |
+
self.relu = nn.ReLU()
|
| 20 |
+
|
| 21 |
+
def encode(self, x):
|
| 22 |
+
# q_phi(z/x)
|
| 23 |
+
h = self.relu(self.img_2hid(x))
|
| 24 |
+
mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
|
| 25 |
+
|
| 26 |
+
return mu, sigma
|
| 27 |
+
|
| 28 |
+
def decode(self, z):
|
| 29 |
+
# p_theta(x/z)
|
| 30 |
+
h = self.relu(self.z_2hi(z))
|
| 31 |
+
x = self.hid_2img(h)
|
| 32 |
+
return torch.sigmoid(x) # image values should be between zero and one.
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
mu, sigma = self.encode(x)
|
| 36 |
+
# parametirazation trick
|
| 37 |
+
epsilon = torch.randn_like(sigma) # Returns a tensor with the same size as input that is filled with random numbers from a normal distribution with mean 0 and variance 1
|
| 38 |
+
z_reparametrized = mu + sigma * epsilon
|
| 39 |
+
x_reconstructed = self.decode(z_reparametrized)
|
| 40 |
+
return x_reconstructed, mu, sigma # 2 parts of loss: 1- mu, sigma pushed to normal distribution. 2 the x_reconstructed should be same as x
|
| 41 |
+
|
| 42 |
+
if __name__ == "__main__":
|
| 43 |
+
|
| 44 |
+
x = torch.randn(4,28*28)
|
| 45 |
+
vae = VariationalAutoEncoder(inpud_dim=784)
|
| 46 |
+
x_reconstructed, mu, sigma = vae(x)
|
| 47 |
+
print(x_reconstructed.shape)
|
| 48 |
+
print(mu.shape)
|
| 49 |
+
print(sigma.shape)
|
| 50 |
+
print(torch.mean(mu))
|
original_5.png
ADDED
|
original_8.png
ADDED
|
requirements.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|