Spaces:
Sleeping
Sleeping
File size: 4,471 Bytes
61bd5c8 cc6024c 61bd5c8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
import streamlit as st
import hopsworks
import xgboost
import plotly.graph_objs as go
import plotly.express as px
import joblib
import math
import pandas as pd
import os
# Page banner: app title, a divider line, then the first status heading.
_DIVIDER = "-" * 36
st.title('๐ฎ Customer Churn Prediction')
st.write(_DIVIDER)
st.header('\n๐ก Connecting to Hopsworks Feature Store...')
def header(text):
    """Write a divider line, then *text* as a ``####`` Markdown heading.

    Args:
        text: Heading text. It is appended to ``'#### '`` with ``+``, so a
            non-``str`` argument raises ``TypeError`` (after the divider has
            already been written).
    """
    divider = "-" * 36
    st.write(divider)
    st.write("#### " + text)
# Authenticate against Hopsworks. The API key comes from the environment so it
# is never committed to source; fail with a readable message if it is absent
# instead of an uncaught KeyError traceback.
_api_key = os.environ.get('HOPSWORKS_API_KEY')
if not _api_key:
    st.error('HOPSWORKS_API_KEY environment variable is not set.')
    st.stop()

project = hopsworks.login(project="annikaij", api_key_value=_api_key)
fs = project.get_feature_store()

header('๐ช Retrieving Feature View...')
feature_view = fs.get_feature_view(
    name="churn_feature_view",
    version=1,
)
# Single, properly terminated literal (the status emoji had been mangled into
# a line break that left the string unterminated).
st.text('Done ✅')
header('โ๏ธ Reading DataFrames from Feature View...')


@st.cache_data()
def retrive_data(feature_view=feature_view):
    """Fetch the batch-scoring frame and the full feature-view query result.

    Returns:
        A ``(batch_data, df_all)`` tuple:

        - ``batch_data``: the batch data with the ``customerid`` column removed.
        - ``df_all``: the full query result with the ``churn`` column removed.
    """
    # NOTE(review): the argument 1 is presumably the training-dataset version
    # used for batch scoring — confirm against the hsfs API.
    feature_view.init_batch_scoring(1)
    scoring_frame = feature_view.get_batch_data().drop(columns='customerid')
    full_frame = feature_view.query.read().drop(columns='churn')
    return scoring_frame, full_frame


batch_data, df_all = retrive_data()
st.dataframe(df_all.head())
st.text(f'Shape: {df_all.shape}')
header('๐ฎ Model Retrieving...')


@st.cache_resource()
def get_model(project=project):
    """Download ``churnmodel`` (version 1) from the model registry and load it.

    Cached with ``st.cache_resource`` so the loaded model is created once and
    shared across reruns and sessions. ``st.cache_data`` would instead
    serialize it and hand back a fresh deserialized copy on every call.

    Returns:
        The object loaded from ``churnmodel.pkl`` in the downloaded model
        directory (used below via ``predict`` and ``feature_importances_``).
    """
    mr = project.get_model_registry()
    model = mr.get_model("churnmodel", version=1)
    model_dir = model.download()
    # NOTE: joblib.load unpickles arbitrary objects; only load artifacts from a
    # trusted registry.
    return joblib.load(os.path.join(model_dir, "churnmodel.pkl"))


model = get_model()
st.write(model)
def transform_preds(predictions):
    """Map raw binary model outputs to human-readable labels.

    Elements equal to ``1`` become ``'Churn'``; every other value becomes
    ``'Not Churn'``.

    Args:
        predictions: Iterable of model outputs.

    Returns:
        A list of label strings, one per input element, in input order.
    """
    def _to_label(value):
        if value == 1:
            return 'Churn'
        return 'Not Churn'

    return list(map(_to_label, predictions))
header('๐ Batch Data Prediction...')
st.dataframe(batch_data.head())

# Predict on the batch rows and convert to 'Churn' / 'Not Churn' labels.
predictions = transform_preds(model.predict(batch_data))

# NOTE(review): the labels are attached to df_all positionally, which assumes
# batch_data and df_all hold the same rows in the same order — confirm.
# pandas raises ValueError if the lengths differ.
df_all['Churn'] = predictions
result_table = df_all[['customerid', 'Churn']]

st.text(f'๐ฉ๐ปโโ๏ธ Predictions for 5 rows:\n {predictions[:5]}')
header('๐ณ Prediction by Customer Id...')

# Let the user pick one of the first 15 customer IDs and submit it.
with st.form(key="Selecting Customer ID"):
    option = st.selectbox(
        'Select a Customer ID to return a prediction.',
        result_table.customerid.values[:15],
    )
    submit_button = st.form_submit_button(label='Submit')

if submit_button:
    # Join matching labels into plain text instead of printing a NumPy array
    # repr. IDs are presumably unique, so this is normally a single label.
    matches = result_table.loc[result_table.customerid == option, 'Churn']
    result = ', '.join(str(label) for label in matches)
    st.text(f'๐ฎ๐ปโโ๏ธ Customer ID: {option}')
    st.text(f'๐ฉ๐ปโโ๏ธ Prediction: {result}')
header('๐จ๐ปโ๐จ Prediction Visualizing...')

# Pair each model input column with the model's importance score and plot them
# from most to least important.
feature_names = batch_data.columns
feature_importance = pd.DataFrame({
    "feature": feature_names,
    "importance": model.feature_importances_,
}).sort_values(by=["importance"], ascending=False)

fig_importance = px.bar(
    feature_importance,
    x='feature',
    y='importance',
    title='Feature Importance Plot',
)
fig_importance.update_xaxes(tickangle=23, title="Feature")
fig_importance.update_yaxes(title="Importance")
fig_importance.update_traces(hovertemplate='Feature: %{x} <br>Importance: %{y}')
st.plotly_chart(fig_importance)
def plot_histogram(data, x_col, title, xlabel, ylabel, color_col="Churn"):
    """Build a Plotly Express histogram of *x_col* split by *color_col*.

    Args:
        data: DataFrame containing both ``x_col`` and ``color_col``.
        x_col: Column to bin along the x axis.
        title: Figure title.
        xlabel: X-axis title; also the x label in the hover text.
        ylabel: Y-axis title; also the y label in the hover text.
        color_col: Column used to colour/split the bars (default ``"Churn"``).

    Returns:
        The configured Plotly figure.
    """
    fig = px.histogram(data, x=x_col, color=color_col, title=title)
    fig.update_xaxes(title=xlabel)
    fig.update_yaxes(title=ylabel)
    fig.update_traces(hovertemplate=f'{xlabel}: %{{x}} <br>{ylabel}: %{{y}}')
    return fig
# One histogram per (column, title, x-axis label, y-axis label), each split by
# the Churn label. px.histogram counts rows by default, so every y axis is a
# customer count.
_HISTOGRAM_SPECS = [
    ('internetservice', 'Churn rate according to internet service subscription',
     'Internet service', 'Number of customers'),
    ('streamingmovies', 'Churn rate according to streaming movies subscription',
     'Streaming movies', 'Number of customers'),
    ('streamingtv', 'Churn rate according to streaming TV subscription',
     'Streaming TV', 'Number of customers'),
    ('gender', 'Churn rate according to Gender',
     'Gender', 'Number of customers'),
    ('totalcharges', 'Distribution of Total Charges according to Churn/Not',
     'Charge Value', 'Number of customers'),
    ('paymentmethod', 'Amount of each Payment Method',
     'Payment Method', 'Number of customers'),
    ('partner', 'Effect of having a partner on Churn/Not',
     'Have a partner', 'Number of customers'),
]
for _col, _title, _xlabel, _ylabel in _HISTOGRAM_SPECS:
    st.plotly_chart(plot_histogram(df_all, _col, _title, _xlabel, _ylabel))
# Final status banner once every section above has rendered.
st.success('๐ ๐ ๐ค App Finished Successfully ๐ค ๐ ๐')