# Hugging Face Spaces app (page status when this copy was captured: "Runtime error").
| import torch | |
| from model import VariationalAutoEncoder | |
| from torchvision import transforms | |
| from PIL import Image | |
| import gradio as gr | |
# Model hyper-parameters. They must match the values used when the checkpoint
# was trained; otherwise load_state_dict() raises on the tensor-size mismatch.
INPUT_DIM = 784  # 28 * 28 flattened MNIST pixels
H_DIM = 512
Z_DIM = 256

model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# host. Without it, torch.load tries to restore the tensors onto CUDA and fails.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load("MnistVAEmodel.pt", map_location="cpu"))
# Inference mode (affects layers such as dropout/batch-norm, if the model has any).
model.eval()
def predict(img):
    """Encode a digit image and return five images decoded from latent samples.

    Args:
        img: Image supplied by the Gradio ``Image`` input. With
            ``image_mode="L"`` and the default ``type="numpy"``, this is a
            2-D uint8 array; a ``PIL.Image`` is also accepted. ``None``
            (input cleared) yields an empty gallery.

    Returns:
        A list of five ``PIL.Image`` objects, each decoded from
        ``z = mu + sigma * eps`` with fresh Gaussian noise ``eps``.
    """
    if img is None:
        return []

    # Normalise to a 28x28 single-channel image so the flattened tensor always
    # has INPUT_DIM elements. Gradio 4 no longer resizes via ``shape=``.
    if not isinstance(img, Image.Image):
        img = Image.fromarray(img)
    img = img.convert("L")
    if img.size != (28, 28):
        img = img.resize((28, 28))

    x = transforms.ToTensor()(img).reshape(1, INPUT_DIM)
    to_pil = transforms.ToPILImage()

    samples = []
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        # NOTE(review): sigma is used directly as a standard deviation, as in
        # the original code; confirm encode() returns (mean, std) and not log-var.
        mu, sigma = model.encode(x)
        for _ in range(5):
            # Reparameterisation trick: z = mu + sigma * eps, eps ~ N(0, I).
            z = mu + sigma * torch.randn_like(sigma)
            out = model.decode(z).view(-1, 1, 28, 28)
            samples.append(to_pil(out[0]))
    return samples


title = "Variational-Autoencoder-on-MNIST "
description = (
    "Upload a handwritten digit (it is converted to 28x28 grayscale). "
    "The VAE encodes it and shows five images decoded from random samples "
    "of its latent distribution."
)
examples = ["original_8.png"]

# gr.Image replaces the gr.inputs namespace (deprecated in Gradio 3, removed in
# Gradio 4). Resizing to 28x28 is done in predict() because Gradio 4 dropped
# the ``shape=`` argument.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(image_mode="L"),
    outputs=gr.Gallery(),
    examples=examples,
    title=title,
    description=description,
)
demo.launch(inline=False)