Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import streamlit as st | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import disvae | |
| import transforms as trans | |
# Location of the trained model; read by load_decode_function() via
# disvae.get_kl_dict and disvae.load_model.
P_MODEL: str = "model/drilling_ds_btcvae"

# Spacing between consecutive generated samples, in seconds (the plot's
# x-axis is labelled "Time t [s]").
SAMPLING_TIME: float = 0.15

# Browser-tab title for the app.
st.set_page_config(page_title="Drilling VAE")
def load_decode_function():
    """Assemble and return a decoder that maps a latent vector to a series.

    The pipeline built here applies, in order:

    1. ``LatentSorter.inv`` -- a sorter constructed from the KL dictionary
       of the model stored at ``P_MODEL``;
    2. the VAE's ``decoder``;
    3. the inverse of a ``MinMaxScaler`` with fixed bounds;
    4. the inverse of a ``SumField`` transform.

    Returns:
        A callable ``decode(latent)`` that runs the pipeline through
        ``trans.np_sample`` with autograd disabled.
    """
    latent_sorter = trans.LatentSorter(disvae.get_kl_dict(P_MODEL))
    model = disvae.load_model(P_MODEL)
    # Fixed de-normalisation bounds; presumably these mirror the scaling used
    # when the model was trained -- NOTE(review): confirm against training code.
    output_scaler = trans.MinMaxScaler(
        _min=torch.tensor([1.3]),
        _max=torch.tensor([4.0]),
        min_norm=0.3,
        max_norm=0.6,
    )
    field = trans.SumField()

    steps = (latent_sorter.inv, model.decoder, output_scaler.inv, field.inv)
    pipeline = trans.sequential_function(*steps)

    def decode(latent):
        """Run ``latent`` through the decoding pipeline without gradients."""
        with torch.no_grad():
            sampler = trans.np_sample(pipeline)
            return sampler(latent)

    return decode
# Streamlit re-executes this whole script on every widget interaction. Wrapping
# the factory in st.cache_resource builds the decoder (and loads the model)
# once and reuses it across reruns instead of reloading it on each slider move.
decode = st.cache_resource(load_decode_function)()

# The first N_SLIDER_DIMS latent dimensions are controlled by sliders; the
# remaining N_FIXED_DIMS are held at 0.
N_SLIDER_DIMS = 3
N_FIXED_DIMS = 7

col1, col2 = st.columns(2)

with col1:
    st.markdown("**Latent Space Parameters**")
    latent_vector = np.array(
        [
            st.slider(
                f"Latent Dimension {dim}",
                min_value=-3.0,
                max_value=3.0,
                value=0.0,
            )
            for dim in range(N_SLIDER_DIMS)
        ]
    )
    latent_vector = np.concatenate(
        [latent_vector, np.zeros(N_FIXED_DIMS)], axis=0
    )
    ts = decode(latent_vector)

with col2:
    st.markdown("**Generated Time Series**")
    # Flatten once so the sample count and the plotted values always agree.
    series = np.asarray(ts).ravel()
    # Build the time axis from integer sample indices. A float-step
    # np.arange(0, n * dt, dt) can return n + 1 points through rounding,
    # which makes ax.plot fail with a length mismatch.
    time_axis = np.arange(series.size) * SAMPLING_TIME

    fig, ax = plt.subplots(figsize=(4, 3))
    ax.plot(time_axis, series)
    ax.set_xlabel("Time t [s]")
    ax.set_ylabel("Spindle torque [Nm]")
    ax.set_ylim([0, 4])
    ax.grid()
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this, each
    # rerun of the script would leak one figure.
    plt.close(fig)