Annikaijak commited on
Commit
61bd5c8
ยท
verified ยท
1 Parent(s): 4bf406d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -0
app.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import hopsworks
3
+ import plotly.graph_objs as go
4
+ import plotly.express as px
5
+ import joblib
6
+ import math
7
+ import pandas as pd
8
+ import os
9
+
10
+ st.title('๐Ÿ”ฎ Customer Churn Prediction')
11
+
12
+ st.write(36 * "-")
13
+ st.header('\n๐Ÿ“ก Connecting to Hopsworks Feature Store...')
14
+
15
+
16
+ def header(text):
17
+ st.write(36 * "-")
18
+ st.write('#### ' + text)
19
+
20
+ project = hopsworks.login(project = "annikaij", api_key_value=os.environ['HOPSWORKS_API_KEY'])
21
+ fs = project.get_feature_store()
22
+
23
+ header('๐Ÿช„ Retrieving Feature View...')
24
+
25
+ feature_view = fs.get_feature_view(
26
+ name="churn_feature_view",
27
+ version=1
28
+ )
29
+
30
+ st.text('Done โœ…')
31
+ header('โš™๏ธ Reading DataFrames from Feature View...')
32
+
33
+ @st.cache_data()
34
+ def retrive_data(feature_view=feature_view):
35
+ feature_view.init_batch_scoring(1)
36
+ batch_data = feature_view.get_batch_data()
37
+ batch_data.drop('customerid', axis=1, inplace=True)
38
+ df_all = feature_view.query.read()
39
+ df_all.drop('churn', axis=1, inplace=True)
40
+ return batch_data, df_all
41
+
42
+
43
+ batch_data, df_all = retrive_data()
44
+
45
+ st.dataframe(df_all.head())
46
+ st.text(f'Shape: {df_all.shape}')
47
+ header('๐Ÿ”ฎ Model Retrieving...')
48
+
49
+
50
+ @st.cache_data()
51
+ def get_model(project=project):
52
+ mr = project.get_model_registry()
53
+ model = mr.get_model("churnmodel", version=1)
54
+ model_dir = model.download()
55
+ return joblib.load(model_dir + "/churnmodel.pkl")
56
+
57
+
58
+ model = get_model()
59
+ st.write(model)
60
+
61
+
62
+ def transform_preds(predictions):
63
+ return ['Churn' if pred == 1 else 'Not Churn' for pred in predictions]
64
+
65
+
66
+ header('๐Ÿ“ Batch Data Prediction...')
67
+
68
+ st.dataframe(batch_data.head())
69
+
70
+ predictions = model.predict(batch_data)
71
+ predictions = transform_preds(predictions)
72
+
73
+ df_all['Churn'] = predictions
74
+
75
+ result_table = df_all[['customerid', 'Churn']]
76
+
77
+ st.text(f'๐Ÿ‘ฉ๐Ÿปโ€โš–๏ธ Predictions for 5 rows:\n {predictions[:5]}')
78
+ header('๐Ÿ’ณ Prediction by Customer Id...')
79
+
80
+ with st.form(key="Selecting Customer ID"):
81
+ option = st.selectbox(
82
+ 'Select a Custimer ID to return a predict.',
83
+ (result_table.customerid.values[:15])
84
+ )
85
+ submit_button = st.form_submit_button(label='Submit')
86
+
87
+ if submit_button:
88
+ result = result_table[result_table.customerid == option]['Churn'].values
89
+
90
+ st.text(f'๐Ÿ‘ฎ๐Ÿปโ€โ™‚๏ธ Customer ID: {option}')
91
+ st.text(f'๐Ÿ‘ฉ๐Ÿปโ€โš–๏ธ Prediction: {result}')
92
+
93
+ header('๐Ÿ‘จ๐Ÿปโ€๐ŸŽจ Prediction Visualizing...')
94
+
95
+
96
+ feature_names = batch_data.columns
97
+
98
+ feature_importance = pd.DataFrame(feature_names, columns=["feature"])
99
+ feature_importance["importance"] = model.feature_importances_
100
+ feature_importance = feature_importance.sort_values(by=["importance"], ascending=False)
101
+
102
+ fig_importance = px.bar(
103
+ feature_importance,
104
+ x='feature',
105
+ y='importance',
106
+ title='Feature Importance Plot'
107
+ )
108
+
109
+ fig_importance.update_xaxes(tickangle=23)
110
+ fig_importance.update_xaxes(title="Feature")
111
+ fig_importance.update_yaxes(title="Importance")
112
+ fig_importance.update_traces(hovertemplate='Feature: %{x} <br>Importance: %{y}')
113
+
114
+ st.plotly_chart(fig_importance)
115
+
116
+
117
+ def plot_histogram(data, x_col, title, xlabel, ylabel):
118
+
119
+ fig = go.Figure()
120
+
121
+ fig = px.histogram(
122
+ data,
123
+ x=x_col,
124
+ color="Churn",
125
+ title=title
126
+ )
127
+
128
+ fig.update_xaxes(title=xlabel)
129
+ fig.update_yaxes(title=ylabel)
130
+ fig.update_traces(hovertemplate=xlabel + ': %{x} <br>' + ylabel + ': %{y}')
131
+
132
+ return fig
133
+
134
+
135
+ st.plotly_chart(plot_histogram(df_all, 'internetservice', 'Churn rate according to internet service subscribtion', 'Internet service', 'Number of customers'))
136
+ st.plotly_chart(plot_histogram(df_all, 'streamingmovies', 'Churn rate according to streaming movies subscribtion', 'Streaming movies', 'Number of customers'))
137
+ st.plotly_chart(plot_histogram(df_all, 'streamingtv', 'Churn rate according to internet streaming tv subscribtion', 'Gender', 'Number of customers'))
138
+ st.plotly_chart(plot_histogram(df_all, 'gender', 'Churn rate according to Gender', 'Gender', 'Number of customers'))
139
+ st.plotly_chart(plot_histogram(df_all, 'totalcharges', 'Distribution of Total Charges according to Churn/Not', "Charge Value", 'Number of customers'))
140
+ st.plotly_chart(plot_histogram(df_all, 'paymentmethod', 'Amount of each Payment Method', "Payment Method", 'Total Amount'))
141
+ st.plotly_chart(plot_histogram(df_all, 'partner', 'Affect of having a partner on Churn/Not', "Have a partner", 'Number of customers'))
142
+
143
+ st.success('๐ŸŽ‰ ๐Ÿ“ˆ ๐Ÿค App Finished Successfully ๐Ÿค ๐Ÿ“ˆ ๐ŸŽ‰')