# File size: 4,471 Bytes
# 61bd5c8

# cc6024c
# 61bd5c8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import streamlit as st
import hopsworks
import xgboost
import plotly.graph_objs as go
import plotly.express as px
import joblib
import math
import pandas as pd
import os

# Page title and the first status banner. The divider is written directly
# here because header() is only defined further down the script.
st.title('🔮 Customer Churn Prediction')

st.write(36 * "-")
st.header('\n📡 Connecting to Hopsworks Feature Store...')


def header(text):
    """Write a 36-dash divider line, then ``text`` prefixed with '#### '."""
    divider = "-" * 36
    st.write(divider)

    heading = '#### ' + text
    st.write(heading)

# Fail with a readable message instead of a raw KeyError traceback when the
# API key has not been configured for this deployment.
api_key = os.environ.get('HOPSWORKS_API_KEY')
if not api_key:
    st.error('The HOPSWORKS_API_KEY environment variable is not set.')
    st.stop()

# Log in to the Hopsworks project and grab its feature store handle.
project = hopsworks.login(project="annikaij", api_key_value=api_key)
fs = project.get_feature_store()

header('🪄 Retrieving Feature View...')

feature_view = fs.get_feature_view(
    name="churn_feature_view",
    version=1,
)

st.text('Done ✅')
header('⚙️ Reading DataFrames from Feature View...')

@st.cache_data()
def retrive_data(feature_view=feature_view):
    """Read the scoring features and the full feature-view data.

    Returns a ``(batch_data, df_all)`` tuple: the batch data with the
    ``customerid`` column removed, and the result of the feature view's
    query with the ``churn`` column removed.
    """
    feature_view.init_batch_scoring(1)

    # Features used for prediction; the id column is dropped from them.
    scoring_frame = feature_view.get_batch_data()
    scoring_frame.drop(columns='customerid', inplace=True)

    # Full query result, minus the churn column.
    full_frame = feature_view.query.read()
    full_frame.drop(columns='churn', inplace=True)

    return scoring_frame, full_frame


# Load both frames (cached by retrive_data) and preview the full data set.
batch_data, df_all = retrive_data()

st.dataframe(df_all.head())
st.text(f'Shape: {df_all.shape}')
header('🔮 Model Retrieving...')


@st.cache_resource()
def get_model(project=project):
    """Download the registered churn model and load it from disk.

    Fetches version 1 of ``churnmodel`` from the project's model registry,
    downloads its files and loads ``churnmodel.pkl`` with joblib.

    The loaded model is cached with ``st.cache_resource`` so reruns reuse one
    shared object; ``st.cache_data`` is meant for serializable data and would
    pickle and copy the model on every cache hit.
    """
    mr = project.get_model_registry()
    model = mr.get_model("churnmodel", version=1)
    model_dir = model.download()
    # NOTE(review): joblib.load unpickles the file, which is only safe as long
    # as the model registry contents are trusted.
    return joblib.load(os.path.join(model_dir, "churnmodel.pkl"))


# Load the model through get_model() and render the object on the page.
model = get_model()
st.write(model)


def transform_preds(predictions):
    """Translate raw predictions into readable churn labels.

    Every prediction equal to 1 becomes 'Churn'; anything else becomes
    'Not Churn'. Returns a list in the same order as ``predictions``.
    """
    labels = []
    for pred in predictions:
        if pred == 1:
            labels.append('Churn')
        else:
            labels.append('Not Churn')
    return labels


header('📝 Batch Data Prediction...')

st.dataframe(batch_data.head())

# Predict on the batch features and convert the outputs to text labels.
predictions = model.predict(batch_data)
predictions = transform_preds(predictions)

# NOTE(review): predictions are assigned positionally, so this assumes
# batch_data and df_all hold the same rows in the same order — confirm
# against the feature view.
df_all['Churn'] = predictions

result_table = df_all[['customerid', 'Churn']]

st.text(f'👩🏻‍⚖️ Predictions for 5 rows:\n {predictions[:5]}')
header('💳 Prediction by Customer Id...')

# Let the user pick one of the first 15 customer ids and submit the form.
with st.form(key="Selecting Customer ID"):
    option = st.selectbox(
        'Select a Customer ID to return a prediction.',
        result_table.customerid.values[:15],
    )
    submit_button = st.form_submit_button(label='Submit')

if submit_button:
    matches = result_table.loc[result_table.customerid == option, 'Churn']
    # Show the label itself rather than the repr of a one-element array.
    result = matches.iloc[0] if not matches.empty else 'Unknown'

    st.text(f'👮🏻‍♂️ Customer ID: {option}')
    st.text(f'👩🏻‍⚖️ Prediction: {result}')

header('👨🏻‍🎨 Prediction Visualizing...')


# Pair every model input column with its importance, largest first.
feature_names = batch_data.columns

feature_importance = pd.DataFrame(feature_names, columns=["feature"])
feature_importance["importance"] = model.feature_importances_
feature_importance = feature_importance.sort_values(by=["importance"], ascending=False)

fig_importance = px.bar(
    feature_importance,
    x='feature',
    y='importance',
    title='Feature Importance Plot',
)

fig_importance.update_xaxes(title="Feature", tickangle=23)
fig_importance.update_yaxes(title="Importance")
fig_importance.update_traces(hovertemplate='Feature: %{x} <br>Importance: %{y}')

st.plotly_chart(fig_importance)


def plot_histogram(data, x_col, title, xlabel, ylabel):
    """Build a histogram of ``x_col`` with bars colored by the 'Churn' column.

    Args:
        data: Data passed to ``px.histogram``; it must contain ``x_col`` and
            a 'Churn' column.
        x_col: Name of the column plotted on the x axis.
        title: Chart title.
        xlabel: X-axis title, also used in the hover text.
        ylabel: Y-axis title, also used in the hover text.

    Returns:
        The figure created by ``px.histogram`` with axis titles and hover
        template applied.
    """
    fig = px.histogram(
        data,
        x=x_col,
        color="Churn",
        title=title,
    )

    fig.update_xaxes(title=xlabel)
    fig.update_yaxes(title=ylabel)
    # Doubled braces keep plotly's %{x} / %{y} placeholders literal.
    fig.update_traces(hovertemplate=f'{xlabel}: %{{x}} <br>{ylabel}: %{{y}}')

    return fig


# One (column, chart title, x-axis label, y-axis label) entry per histogram,
# rendered in this order.
histogram_specs = [
    ('internetservice', 'Churn rate according to internet service subscription', 'Internet service', 'Number of customers'),
    ('streamingmovies', 'Churn rate according to streaming movies subscription', 'Streaming movies', 'Number of customers'),
    ('streamingtv', 'Churn rate according to internet streaming tv subscription', 'Streaming TV', 'Number of customers'),
    ('gender', 'Churn rate according to Gender', 'Gender', 'Number of customers'),
    ('totalcharges', 'Distribution of Total Charges according to Churn/Not', 'Charge Value', 'Number of customers'),
    ('paymentmethod', 'Amount of each Payment Method', 'Payment Method', 'Total Amount'),
    ('partner', 'Effect of having a partner on Churn/Not', 'Have a partner', 'Number of customers'),
]

for _x_col, _title, _xlabel, _ylabel in histogram_specs:
    st.plotly_chart(plot_histogram(df_all, _x_col, _title, _xlabel, _ylabel))

st.success('🎉 📈 🤝 App Finished Successfully 🤝 📈 🎉')