3v324v23 commited on
Commit
e55cbec
·
1 Parent(s): d6d075e

revise plot using plotly express

Browse files
Files changed (5) hide show
  1. app.py +3 -0
  2. csv/BankPCA.csv +0 -0
  3. csv/Cluster.csv +0 -0
  4. eda.py +12 -21
  5. prediction.py +22 -6
app.py CHANGED
@@ -14,6 +14,7 @@ with st.sidebar:
14
  [
15
  "Distribution",
16
  "Classification",
 
17
  ],
18
  icons=["bar-chart", "link-45deg", "code-square"],
19
  menu_icon="cast",
@@ -24,3 +25,5 @@ if selected == "Distribution":
24
  eda.distribution()
25
  elif selected == "Classification":
26
  prediction.predict()
 
 
 
14
  [
15
  "Distribution",
16
  "Classification",
17
+ "Cluster"
18
  ],
19
  icons=["bar-chart", "link-45deg", "code-square"],
20
  menu_icon="cast",
 
25
  eda.distribution()
26
  elif selected == "Classification":
27
  prediction.predict()
28
+ elif selected == "Cluster":
29
+ prediction.cluster()
csv/BankPCA.csv ADDED
The diff for this file is too large to render. See raw diff
 
csv/Cluster.csv ADDED
The diff for this file is too large to render. See raw diff
 
eda.py CHANGED
@@ -1,8 +1,6 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
- import seaborn as sns
5
- import matplotlib.pyplot as plt
6
  import plotly.express as px
7
 
8
  bank_df = pd.read_csv('./csv/BankChurners.csv')
@@ -10,8 +8,6 @@ bank_df.drop(columns=["CLIENTNUM",
10
  "Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_1",
11
  "Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_2"], inplace=True)
12
 
13
- sns.set(style="whitegrid")
14
- palette=["teal", "darkblue"]
15
  num_col = bank_df.select_dtypes(include=np.number).columns.tolist()
16
  cat_col = bank_df.select_dtypes(include=object).columns.tolist()
17
  cat_col.remove("Attrition_Flag")
@@ -24,20 +20,17 @@ st.set_page_config(
24
 
25
  def distribution():
26
  # distribution plot
27
-
28
  st.header("Data Distribution")
29
 
30
  attr_plot('Attrition_Flag')
31
 
32
  col1, col2 = st.columns(2)
33
 
34
- numerik = col1.selectbox(label="Select Features",
35
- options=num_col)
36
 
37
  hist_plot(numerik, col1)
38
 
39
- kategorik = col2.selectbox(label="Select Features",
40
- options=cat_col)
41
 
42
  count_plot(kategorik, col2)
43
 
@@ -46,19 +39,17 @@ def distribution():
46
  ''')
47
 
48
  def attr_plot(column):
49
- fig = plt.figure(figsize=(15,5))
50
- sns.countplot(data=bank_df, y=column, palette=palette, alpha=0.7)
51
- st.pyplot(fig)
52
 
53
  def hist_plot(column, loc):
54
- fig = plt.figure(figsize=(15,6))
55
- sns.histplot(data=bank_df, x=column, kde=True, bins=50, palette=palette, hue="Attrition_Flag")
56
- loc.pyplot(fig)
57
-
58
- def count_plot(column,loc):
59
- fig = plt.figure(figsize=(15,6))
60
- sns.countplot(data=bank_df, y=column, palette=palette, hue="Attrition_Flag", order=bank_df[column].value_counts().index, alpha=0.7)
61
- loc.pyplot(fig)
62
-
63
  if __name__ == "__main__":
64
  distribution()
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
 
 
4
  import plotly.express as px
5
 
6
  bank_df = pd.read_csv('./csv/BankChurners.csv')
 
8
  "Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_1",
9
  "Naive_Bayes_Classifier_Attrition_Flag_Card_Category_Contacts_Count_12_mon_Dependent_count_Education_Level_Months_Inactive_12_mon_2"], inplace=True)
10
 
 
 
11
  num_col = bank_df.select_dtypes(include=np.number).columns.tolist()
12
  cat_col = bank_df.select_dtypes(include=object).columns.tolist()
13
  cat_col.remove("Attrition_Flag")
 
20
 
21
  def distribution():
22
  # distribution plot
 
23
  st.header("Data Distribution")
24
 
25
  attr_plot('Attrition_Flag')
26
 
27
  col1, col2 = st.columns(2)
28
 
29
+ numerik = col1.selectbox(label="Select Features", options=num_col)
 
30
 
31
  hist_plot(numerik, col1)
32
 
33
+ kategorik = col2.selectbox(label="Select Features", options=cat_col)
 
34
 
35
  count_plot(kategorik, col2)
36
 
 
39
  ''')
40
 
41
  def attr_plot(column):
42
+ fig = px.histogram(bank_df, y=column, color="Attrition_Flag", title=f'Distribution of {column}')
43
+ fig.update_layout(width=1200)
44
+ st.plotly_chart(fig, use_container_width=True)
45
 
46
  def hist_plot(column, loc):
47
+ fig = px.histogram(bank_df, x=column, color="Attrition_Flag", title=f'Histogram of {column}')
48
+ loc.plotly_chart(fig)
49
+
50
+ def count_plot(column, loc):
51
+ fig = px.bar(bank_df, y=column, color="Attrition_Flag", title=f'Count Plot of {column}', orientation='h')
52
+ loc.plotly_chart(fig)
53
+
 
 
54
  if __name__ == "__main__":
55
  distribution()
prediction.py CHANGED
@@ -1,11 +1,7 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import numpy as np
4
  import pickle
5
- import matplotlib.pyplot as plt
6
- from PIL import Image
7
- from urllib import request
8
- from io import BytesIO
9
 
10
 
11
  def predict():
@@ -103,6 +99,26 @@ def predict():
103
  """
104
  st.markdown(result_html.format(pred_inf=pred_inf, cluster_inf=cluster_inf, color=color, step=recommendation), unsafe_allow_html=True)
105
 
106
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
  if __name__ == "__main__":
108
  predict()
 
1
  import streamlit as st
2
  import pandas as pd
 
3
  import pickle
4
+ import plotly.express as px
 
 
 
5
 
6
 
7
  def predict():
 
99
  """
100
  st.markdown(result_html.format(pred_inf=pred_inf, cluster_inf=cluster_inf, color=color, step=recommendation), unsafe_allow_html=True)
101
 
102
+ def cluster():
103
+ clusters = pd.read_csv('./csv/Cluster.csv')
104
+ bank_df_pca = pd.read_csv('./csv/BankPCA.csv')
105
+
106
+ colors = {0: 'navy', 1: 'teal'}
107
+ names = {0: 'High Spent Amount (>4K), High Usage Frequency',
108
+ 1: 'Low Spent Amount (<4K), Low Usage Frequency'}
109
+
110
+ bank_df_pca['color'] = bank_df_pca['label'].map(colors)
111
+ bank_df_pca['name'] = bank_df_pca['label'].map(names)
112
+
113
+ fig = px.scatter(bank_df_pca, x='x', y='y', color='name', hover_name='name',
114
+ title='Churn Customer Clustering', width=800, height=400, )
115
+
116
+ fig.update_traces(marker=dict(size=5))
117
+ fig.update_layout(showlegend=True)
118
+
119
+ fig.update_layout(height=600)
120
+ st.plotly_chart(fig, use_container_width=True)
121
+
122
+
123
  if __name__ == "__main__":
124
  predict()