added files
Browse files
app.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
from normflows import nflow
|
| 4 |
+
import numpy as np
|
| 5 |
+
import seaborn as sns
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
uploaded_file = st.file_uploader("Choose original dataset")
|
| 9 |
+
bw = st.number_input('Scale',value=3.05)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute():
|
| 14 |
+
api = nflow(dim=8,latent=16,dataset=uploaded_file)
|
| 15 |
+
api.compile(optim=torch.optim.ASGD,bw=bw,lr=0.0001,wd=None)
|
| 16 |
+
api.train(iters=10000)
|
| 17 |
+
samples = np.array(api.model.sample(
|
| 18 |
+
torch.tensor(api.scaled).float()).detach())
|
| 19 |
+
|
| 20 |
+
# fig, ax = plt.subplots()
|
| 21 |
+
g = sns.jointplot(x=samples[:, 0], y=samples[:, 1], kind='kde',cmap=sns.color_palette("Blues", as_cmap=True),fill=True,label='Gaussian KDE',levels=50)
|
| 22 |
+
|
| 23 |
+
w = sns.scatterplot(x=api.scaled[:,0],y=api.scaled[:,1],ax=g.ax_joint,c='orange',marker='+',s=100,label='Real')
|
| 24 |
+
st.pyplot(w.get_figure())
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def random_normal_samples(n, dim=2):
|
| 28 |
+
return torch.zeros(n, dim).normal_(mean=0, std=1)
|
| 29 |
+
|
| 30 |
+
samples = np.array(api.model.sample(torch.tensor(random_normal_samples(1000,api.scaled.shape[-1])).float()).detach())
|
| 31 |
+
|
| 32 |
+
return api.scaler.inverse_transform(samples)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
if uploaded_file is not None:
|
| 37 |
+
samples=compute()
|
| 38 |
+
st.download_button('Download generated CSV', pd.DataFrame(samples).to_csv(), 'text/csv')
|