Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| import disvae | |
| import transforms as trans | |
def load_decode_function(model_dir="models/btcvae_celeba"):
    """Build a decoder callable for a trained VAE stored in ``model_dir``.

    Loads the model with ``disvae.load_model`` and its KL dictionary with
    ``disvae.get_kl_dict``. Builds a ``trans.LatentSorter`` from that
    dictionary, and chains ``sorter.inv`` with ``vae.decoder`` through
    ``trans.sequential_function``. Presumably latents are given in the
    sorter's (KL-based) order and ``sorter.inv`` restores the model's native
    ordering before decoding. TODO: confirm this against ``transforms``.

    Args:
        model_dir: Directory of the trained model. Defaults to the CelebA
            beta-TCVAE checkpoint this app was written for.

    Returns:
        A function ``decode(latent)`` that runs the chained decoder through
        ``trans.np_sample`` with gradient tracking disabled. It is assumed to
        accept and return NumPy arrays; confirm against ``transforms.np_sample``.
    """
    sorter = trans.LatentSorter(disvae.get_kl_dict(model_dir))
    vae = disvae.load_model(model_dir)
    _dec = trans.sequential_function(
        sorter.inv,
        vae.decoder,
    )
    # Build the NumPy-facing wrapper once rather than on every decode call.
    _np_dec = trans.np_sample(_dec)

    def decode(latent):
        """Decode ``latent`` without building an autograd graph (inference only)."""
        with torch.no_grad():
            return _np_dec(latent)

    return decode
# GUI -----------------------------------------------------------
# Streamlit re-executes this whole script on every widget interaction. Without
# caching, the model would be reloaded from disk on every slider move.
# ``st.cache_resource`` only exists in newer Streamlit releases, so on older
# ones fall back to the plain (uncached) load.
N_SLIDERS = 3  # number of leading latents exposed as sliders
LATENT_DIM = 10  # total latent size fed to the decoder (3 sliders + 7 zeros); confirm against the model

_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is None:
    decode = load_decode_function()
else:
    decode = _cache_resource(load_decode_function)()

sliders = [
    st.slider(f"L{i}", min_value=-3.0, max_value=3.0, value=0.0)
    for i in range(N_SLIDERS)
]
# Latents without a slider are fixed at 0.
latent_vector = np.concatenate(
    [np.array(sliders), np.zeros(LATENT_DIM - N_SLIDERS)], axis=0
)
value = decode(latent_vector)
# Reorder the leading axes (0, 1, 2) -> (1, 2, 0). This presumably converts
# channel-first (C, H, W) output into channel-last (H, W, C) for st.image.
value = np.swapaxes(np.swapaxes(value, 0, 2), 0, 1)
st.image(value, use_column_width="always")