# grow / app.py
# toandaominh1997 — "update demo" (dcd3022)
import pandas as pd
import streamlit as st
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
from joblib import dump, load
from pathlib import Path
st.title("Iris Inference")
cols = st.columns(2)

# Left column: the four numeric inputs. Each widget stores its current value
# in ``st.session_state`` under its ``key``, which is read back below.
with cols[0]:
    st.header('Input Features')
    st.number_input('sepal length in cm', key='sl')
    st.number_input('sepal width in cm', key='sw')
    st.number_input('petal length in cm', key='pl')
    # Label previously read "pepal length"; this widget feeds the
    # 'petal width (cm)' column, so it must say petal width.
    st.number_input('petal width in cm', key='pw')

    # One-row feature frame. The column names are kept byte-for-byte because
    # they presumably must match the feature names the model was trained on
    # (they are sklearn's iris feature names) — TODO confirm against training.
    data = pd.DataFrame.from_dict({
        'sepal length (cm)': [float(st.session_state.sl)],
        'sepal width (cm)': [float(st.session_state.sw)],
        'petal length (cm)': [float(st.session_state.pl)],
        'petal width (cm)': [float(st.session_state.pw)],
    })

    # NOTE: the button's return value is discarded, so clicking it does not
    # gate anything; the prediction further down is computed on every run of
    # the script regardless of whether Submit was pressed.
    st.button('Submit')
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the serialized classifier from ``weights/rf.joblib``.

    The path is resolved against the current working directory
    (``Path.cwd()``), so the app must be launched from the directory that
    contains ``weights/``.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching the model would be deserialized from disk on every rerun.
    ``st.cache`` keeps the loaded object across reruns;
    ``allow_output_mutation=True`` stops Streamlit from hashing the returned
    estimator on every cache hit.

    Returns:
        Whatever object ``joblib.load`` reads from the file (presumably a
        fitted scikit-learn classifier — TODO confirm).
    """
    path_model = Path.cwd() / 'weights' / 'rf.joblib'
    return load(str(path_model))
model = load_model()


def predict_iris(data):
    """Return the model's prediction for the first row of *data*.

    Uses the module-level ``model`` loaded above.

    Args:
        data: feature frame passed straight to ``model.predict``; here it is
            the one-row DataFrame built from the input widgets.

    Returns:
        ``model.predict(data)[0]`` — the predicted label for the first row.
    """
    # Intentionally not decorated with the legacy ``@st.cache``: that
    # decorator hashes the arguments *and* every external variable the
    # function reads — including the global fitted ``model`` — on each call,
    # which is pure overhead for memoizing a single-row prediction.
    return model.predict(data)[0]


# Right column: show the predicted class for the current inputs.
with cols[1]:
    st.markdown("## Prediction")
    st.write("Class: ", predict_iris(data))