Afeezee commited on
Commit
79ff9ea
ยท
verified ยท
1 Parent(s): 9f9c3b2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +763 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,765 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ import plotly.graph_objects as go
5
+ import plotly.express as px
6
+ from plotly.subplots import make_subplots
7
+ from sklearn.model_selection import train_test_split
8
+ from sklearn.preprocessing import MinMaxScaler
9
+ from sklearn.linear_model import LogisticRegression
10
+ from sklearn.ensemble import RandomForestClassifier
11
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
12
+ from sklearn.preprocessing import MinMaxScaler
13
+ from sklearn.utils import resample
14
+ import xgboost as xgb
15
+ import pickle
16
+ import io
17
+ import base64
18
+ from datetime import datetime
19
+ import warnings
20
+ warnings.filterwarnings('ignore')
21
+
22
+ # Color palette
23
+ COLORS = {
24
+ 'primary': '#14213d', # Dark blue
25
+ 'secondary': '#fca311', # Orange
26
+ 'background': '#ffffff', # White
27
+ 'light_gray': '#e5e5e5', # Light gray
28
+ 'black': '#000000' # Black
29
+ }
30
+
31
+ # Custom CSS
32
+ def apply_custom_css():
33
+ st.markdown(f"""
34
+ <style>
35
+ .main {{
36
+ background-color: {COLORS['background']};
37
+ }}
38
+
39
+ /* Force all text to be black */
40
+ .stApp, .main, .block-container {{
41
+ color: {COLORS['black']} !important;
42
+ }}
43
+
44
+ /* Override Streamlit's default text colors */
45
+ h1, h2, h3, h4, h5, h6 {{
46
+ color: {COLORS['light_gray']} !important;
47
+ }}
48
+
49
+
50
+ p, div, span {{
51
+ color: {COLORS['black']} !important;
52
+ }}
53
+
54
+ /* Input fields and labels */
55
+ .stTextInput > label, .stSelectbox > label, .stNumberInput > label {{
56
+ color: {COLORS['black']} !important;
57
+ font-weight: bold;
58
+ }}
59
+
60
+ .stTextInput input, .stSelectbox select, .stNumberInput input {{
61
+ color: {COLORS['light_gray']} !important;
62
+ }}
63
+
64
+ /* Success/Error messages */
65
+ .stSuccess, .stError, .stWarning, .stInfo {{
66
+ color: {COLORS['black']} !important;
67
+ }}
68
+
69
+ .stSuccess div, .stError div, .stWarning div, .stInfo div {{
70
+ color: {COLORS['black']} !important;
71
+ }}
72
+
73
+ /* Buttons */
74
+ .stButton > button {{
75
+ background-color: {COLORS['secondary']};
76
+ color: {COLORS['primary']};
77
+ border: none;
78
+ border-radius: 5px;
79
+ font-weight: bold;
80
+ }}
81
+
82
+ .stButton > button:hover {{
83
+ background-color: {COLORS['primary']};
84
+ color: {COLORS['secondary']};
85
+ }}
86
+
87
+ /* Metric cards */
88
+ .metric-card {{
89
+ background-color: {COLORS['light_gray']};
90
+ padding: 20px;
91
+ border-radius: 10px;
92
+ border-left: 5px solid {COLORS['secondary']};
93
+ margin: 10px 0;
94
+ color: {COLORS['black']} !important;
95
+ }}
96
+
97
+ .metric-card h2, .metric-card h3 {{
98
+ color: {COLORS['primary']} !important;
99
+ }}
100
+
101
+ /* Prediction results */
102
+ .prediction-result {{
103
+ background-color: {COLORS['primary']};
104
+ color: {COLORS['background']} !important;
105
+ padding: 15px;
106
+ border-radius: 10px;
107
+ text-align: center;
108
+ margin: 20px 0;
109
+ }}
110
+
111
+ .prediction-result h2, .prediction-result h3 {{
112
+ color: {COLORS['background']} !important;
113
+ }}
114
+
115
+ /* Header text */
116
+ .header-text {{
117
+ color: {COLORS['primary']} !important;
118
+ font-weight: bold;
119
+ }}
120
+
121
+ /* Sidebar text */
122
+ .css-1d391kg, .css-1lcbmhc {{
123
+ color: {COLORS['light_gray']} !important;
124
+ }}
125
+
126
+ /* Dataframe text */
127
+ .dataframe {{
128
+ color: {COLORS['black']} !important;
129
+ }}
130
+
131
+ /* Tab labels */
132
+ .stTabs [data-baseweb="tab-list"] button [data-testid="stMarkdownContainer"] p {{
133
+ color: {COLORS['light_gray']} !important;
134
+ }}
135
+
136
+ /* Markdown text */
137
+ .stMarkdown {{
138
+ color: {COLORS['light_gray']} !important;
139
+ }}
140
+
141
+ /* File uploader */
142
+ .stFileUploader > label {{
143
+ color: {COLORS['black']} !important;
144
+ }}
145
+
146
+ /* Multiselect */
147
+ .stMultiSelect > label {{
148
+ color: {COLORS['black']} !important;
149
+ }}
150
+
151
+ /* Slider */
152
+ .stSlider > label {{
153
+ color: {COLORS['light_gray']} !important;
154
+ }}
155
+
156
+ /* Checkbox */
157
+ .stCheckbox > label {{
158
+ color: {COLORS['black']} !important;
159
+ }}
160
+ </style>
161
+ """, unsafe_allow_html=True)
162
+
163
+ # Initialize session state
164
+ def init_session_state():
165
+ if 'logged_in' not in st.session_state:
166
+ st.session_state.logged_in = False
167
+ if 'model_trained' not in st.session_state:
168
+ st.session_state.model_trained = False
169
+ if 'model' not in st.session_state:
170
+ st.session_state.model = None
171
+ if 'scaler' not in st.session_state:
172
+ st.session_state.scaler = None
173
+ if 'data' not in st.session_state:
174
+ st.session_state.data = None
175
+ if 'model_results' not in st.session_state:
176
+ st.session_state.model_results = None
177
+
178
+ # Login page
179
+ def login_page():
180
+ st.markdown('<h1 class="header-text">๐Ÿฆ Sunrise Microfinance Bank</h1>', unsafe_allow_html=True)
181
+ st.markdown('<h2 class="header-text">Customer Churn Prediction System</h2>', unsafe_allow_html=True)
182
+
183
+ col1, col2, col3 = st.columns([1, 2, 1])
184
+
185
+ with col2:
186
+ st.markdown("### Admin Login")
187
+ username = st.text_input("Username", placeholder="Enter admin username")
188
+ password = st.text_input("Password", type="password", placeholder="Enter password")
189
+
190
+ if st.button("Login", use_container_width=True):
191
+ # Simple authentication (in production, use proper authentication)
192
+ if username == "admin" and password == "admin123":
193
+ st.session_state.logged_in = True
194
+ st.success("Login successful!")
195
+ else:
196
+ st.error("Invalid credentials. Use admin/admin123")
197
+
198
+ # Simple oversampling function to replace SMOTE
199
+ def simple_oversample(X, y, random_state=42):
200
+ """Simple oversampling by duplicating minority class samples"""
201
+ np.random.seed(random_state)
202
+
203
+ # Combine features and target
204
+ df = pd.concat([X.reset_index(drop=True), y.reset_index(drop=True)], axis=1)
205
+
206
+ # Separate majority and minority classes
207
+ majority_class = df[df[y.name] == 0]
208
+ minority_class = df[df[y.name] == 1]
209
+
210
+ # Oversample minority class
211
+ minority_upsampled = resample(minority_class,
212
+ replace=True,
213
+ n_samples=len(majority_class),
214
+ random_state=random_state)
215
+
216
+ # Combine majority and upsampled minority
217
+ df_upsampled = pd.concat([majority_class, minority_upsampled])
218
+
219
+ # Separate features and target
220
+ X_resampled = df_upsampled.drop(y.name, axis=1)
221
+ y_resampled = df_upsampled[y.name]
222
+
223
+ return X_resampled, y_resampled
224
+
225
+ # Data preprocessing function
226
+ def preprocess_data(df):
227
+ # Drop non-predictive columns
228
+ if 'CustomerId' in df.columns:
229
+ df = df.drop(['CustomerId'], axis=1)
230
+ if 'Surname' in df.columns:
231
+ df = df.drop(['Surname'], axis=1)
232
+
233
+ # Feature encoding
234
+ df['Gender'] = df['Gender'].map({'Male': 0, 'Female': 1})
235
+ df['Account Activity'] = df['Account Activity'].map({'Active': 0, 'Dormant': 1})
236
+ df['Repayment Timeliness'] = df['Repayment Timeliness'].map({'On-time': 0, 'Late': 1})
237
+
238
+ df['Account Balance Trend'] = df['Account Balance Trend'].map({
239
+ 'Decreasing': 0,
240
+ 'Stable': 1,
241
+ 'Increasing': 2
242
+ })
243
+
244
+ # Convert binary columns to int
245
+ binary_columns = ['Use of Savings Products', 'Use of Loan Products', 'Participation in Group Lending']
246
+ for col in binary_columns:
247
+ if col in df.columns:
248
+ df[col] = df[col].astype(int)
249
+
250
+ # One-hot encoding for categorical variables
251
+ categorical_columns = ['Marital Status', 'Education Level', 'Loan History', 'Use of Digital Banking']
252
+ for col in categorical_columns:
253
+ if col in df.columns:
254
+ df = pd.get_dummies(df, columns=[col], prefix=col.replace(' ', '_'))
255
+
256
+ return df
257
+
258
+ # Dashboard page
259
+ def dashboard_page():
260
+ st.markdown('<h1 class="header-text">๐Ÿ“Š Super Admin Dashboard</h1>', unsafe_allow_html=True)
261
+
262
+ if st.session_state.data is not None:
263
+ df = st.session_state.data
264
+
265
+ # Key metrics
266
+ col1, col2, col3, col4 = st.columns(4)
267
+
268
+ with col1:
269
+ st.markdown(f"""
270
+ <div class="metric-card">
271
+ <h3>Total Customers</h3>
272
+ <h2>{len(df)}</h2>
273
+ </div>
274
+ """, unsafe_allow_html=True)
275
+
276
+ with col2:
277
+ churn_rate = df['Exited'].mean() * 100 if 'Exited' in df.columns else 0
278
+ st.markdown(f"""
279
+ <div class="metric-card">
280
+ <h3>Churn Rate</h3>
281
+ <h2>{churn_rate:.1f}%</h2>
282
+ </div>
283
+ """, unsafe_allow_html=True)
284
+
285
+ with col3:
286
+ active_customers = len(df) - df['Exited'].sum() if 'Exited' in df.columns else len(df)
287
+ st.markdown(f"""
288
+ <div class="metric-card">
289
+ <h3>Active Customers</h3>
290
+ <h2>{active_customers}</h2>
291
+ </div>
292
+ """, unsafe_allow_html=True)
293
+
294
+ with col4:
295
+ avg_age = df['Age'].mean() if 'Age' in df.columns else 0
296
+ st.markdown(f"""
297
+ <div class="metric-card">
298
+ <h3>Average Age</h3>
299
+ <h2>{avg_age:.1f}</h2>
300
+ </div>
301
+ """, unsafe_allow_html=True)
302
+
303
+ # Charts
304
+ st.markdown("### ๐Ÿ“ˆ Customer Analytics")
305
+
306
+ if 'Exited' in df.columns:
307
+ col1, col2 = st.columns(2)
308
+
309
+ with col1:
310
+ # Churn distribution
311
+ churn_counts = df['Exited'].value_counts()
312
+ fig = go.Figure(data=[go.Pie(
313
+ labels=['Retained', 'Churned'],
314
+ values=[churn_counts[0], churn_counts[1]],
315
+ marker_colors=[COLORS['secondary'], COLORS['primary']]
316
+ )])
317
+ fig.update_layout(title="Customer Retention vs Churn", title_x=0.5)
318
+ st.plotly_chart(fig, use_container_width=True)
319
+
320
+ with col2:
321
+ # Age distribution by churn
322
+ fig = px.histogram(df, x='Age', color='Exited', nbins=20,
323
+ title="Age Distribution by Churn Status",
324
+ color_discrete_map={0: COLORS['secondary'], 1: COLORS['primary']})
325
+ st.plotly_chart(fig, use_container_width=True)
326
+
327
+ else:
328
+ st.info("Please upload data first to see dashboard metrics.")
329
+
330
+ # Upload data page
331
+ def upload_data_page():
332
+ st.markdown('<h1 class="header-text">๐Ÿ“ Upload Customer Data</h1>', unsafe_allow_html=True)
333
+
334
+ uploaded_file = st.file_uploader(
335
+ "Choose a CSV file",
336
+ type=['csv'],
337
+ help="Upload your customer dataset in CSV format"
338
+ )
339
+
340
+ if uploaded_file is not None:
341
+ try:
342
+ df = pd.read_csv(uploaded_file)
343
+ st.success(f"Data uploaded successfully! {len(df)} records loaded.")
344
+
345
+ # Display data info
346
+ st.markdown("### Data Preview")
347
+ st.dataframe(df.head(10))
348
+
349
+ st.markdown("### Data Summary")
350
+ col1, col2 = st.columns(2)
351
+
352
+ with col1:
353
+ st.markdown("**Dataset Shape:**")
354
+ st.write(f"Rows: {df.shape[0]}")
355
+ st.write(f"Columns: {df.shape[1]}")
356
+
357
+ with col2:
358
+ st.markdown("**Missing Values:**")
359
+ missing_values = df.isnull().sum().sum()
360
+ st.write(f"Total: {missing_values}")
361
+
362
+ # Store data in session state
363
+ st.session_state.data = df
364
+
365
+ if st.button("Process Data", use_container_width=True):
366
+ with st.spinner("Processing data..."):
367
+ processed_df = preprocess_data(df.copy())
368
+ st.session_state.processed_data = processed_df
369
+ st.success("Data processed successfully!")
370
+ st.markdown("### Processed Data Preview")
371
+ st.dataframe(processed_df.head())
372
+
373
+ except Exception as e:
374
+ st.error(f"Error loading data: {str(e)}")
375
+
376
+ # Sample data option
377
+ st.markdown("### Or Use Sample Data")
378
+ if st.button("Load Sample Data"):
379
+ # Create sample data based on your description
380
+ np.random.seed(42)
381
+ n_samples = 1000
382
+
383
+ sample_data = {
384
+ 'CustomerId': [f'SMB{15565700 + i + 1}' for i in range(n_samples)],
385
+ 'Surname': ['Abdullahi', 'Bello', 'Adesina', 'Sule', 'Nwachukwu'] * (n_samples // 5),
386
+ 'Age': np.random.randint(18, 92, n_samples),
387
+ 'Gender': np.random.choice(['Male', 'Female'], n_samples),
388
+ 'Marital Status': np.random.choice(['Single', 'Married', 'Divorced'], n_samples),
389
+ 'Education Level': np.random.choice(['None', 'Primary', 'Secondary', 'Tertiary'], n_samples),
390
+ 'Account Balance Trend': np.random.choice(['Decreasing', 'Stable', 'Increasing'], n_samples),
391
+ 'Loan History': np.random.choice(['Active', 'Cleared', 'Defaulted'], n_samples),
392
+ 'Frequency of Deposits/Withdrawals': np.random.randint(0, 15, n_samples),
393
+ 'Average Transaction Value': np.random.uniform(1000, 50000, n_samples),
394
+ 'Account Activity': np.random.choice(['Active', 'Dormant'], n_samples),
395
+ 'Use of Savings Products': np.random.choice([0, 1], n_samples),
396
+ 'Use of Loan Products': np.random.choice([0, 1], n_samples),
397
+ 'Use of Digital Banking': np.random.choice(['USSD', 'App', 'Both', 'None'], n_samples),
398
+ 'Participation in Group Lending': np.random.choice([0, 1], n_samples),
399
+ 'Tenure': np.random.randint(0, 10, n_samples),
400
+ 'Number of Complaints Logged': np.random.randint(0, 5, n_samples),
401
+ 'Response Time to Complaints': np.random.randint(0, 15, n_samples),
402
+ 'Customer Support Interactions': np.random.randint(0, 10, n_samples),
403
+ 'Repayment Timeliness': np.random.choice(['On-time', 'Late'], n_samples),
404
+ 'Overdue Loan Frequency': np.random.randint(0, 5, n_samples),
405
+ 'Penalties Paid': np.random.uniform(0, 10000, n_samples),
406
+ 'Exited': np.random.choice([0, 1], n_samples, p=[0.8, 0.2])
407
+ }
408
+
409
+ df = pd.DataFrame(sample_data)
410
+ st.session_state.data = df
411
+ st.success("Sample data loaded successfully!")
412
+ st.dataframe(df.head())
413
+
414
+ # Model training page
415
+ def model_training_page():
416
+ st.markdown('<h1 class="header-text">๐Ÿค– Model Training</h1>', unsafe_allow_html=True)
417
+
418
+ if st.session_state.data is None:
419
+ st.warning("Please upload data first.")
420
+ return
421
+
422
+ df = st.session_state.data.copy()
423
+
424
+ st.markdown("### Training Configuration")
425
+
426
+ col1, col2 = st.columns(2)
427
+ with col1:
428
+ test_size = st.slider("Test Size", 0.1, 0.5, 0.3, 0.05)
429
+ use_oversampling = st.checkbox("Use Oversampling for Imbalanced Data", value=True)
430
+
431
+ with col2:
432
+ random_state = st.number_input("Random State", value=42)
433
+
434
+ selected_features = st.multiselect(
435
+ "Select Features for Training",
436
+ ['Age', 'Gender', 'Tenure', 'Frequency of Deposits/Withdrawals',
437
+ 'Repayment Timeliness', 'Account Activity', 'Account Balance Trend'],
438
+ default=['Age', 'Gender', 'Tenure', 'Frequency of Deposits/Withdrawals',
439
+ 'Repayment Timeliness', 'Account Activity', 'Account Balance Trend']
440
+ )
441
+
442
+ if st.button("Train Models", use_container_width=True):
443
+ if not selected_features:
444
+ st.error("Please select at least one feature.")
445
+ return
446
+
447
+ with st.spinner("Training models..."):
448
+ # Preprocess data
449
+ processed_df = preprocess_data(df)
450
+
451
+ # Prepare features and target
452
+ available_features = [f for f in selected_features if f in processed_df.columns]
453
+ X = processed_df[available_features]
454
+ y = processed_df['Exited']
455
+
456
+ # Handle class imbalance with SMOTE
457
+ if use_oversampling:
458
+ X_resampled, y_resampled = simple_oversample(X, y, random_state=random_state)
459
+ else:
460
+ X_resampled, y_resampled = X, y
461
+ # Feature scaling
462
+ scaler = MinMaxScaler()
463
+ X_scaled = scaler.fit_transform(X_resampled)
464
+ X_scaled = pd.DataFrame(X_scaled, columns=X.columns)
465
+
466
+ # Split data
467
+ X_train, X_test, y_train, y_test = train_test_split(
468
+ X_scaled, y_resampled, test_size=test_size, random_state=random_state
469
+ )
470
+
471
+ # Train models
472
+ models = {
473
+ 'Logistic Regression': LogisticRegression(random_state=random_state),
474
+ 'Random Forest': RandomForestClassifier(random_state=random_state, n_estimators=100),
475
+ 'XGBoost': xgb.XGBClassifier(random_state=random_state, use_label_encoder=False, eval_metric='logloss')
476
+ }
477
+
478
+ results = {}
479
+ trained_models = {}
480
+
481
+ for name, model in models.items():
482
+ model.fit(X_train, y_train)
483
+ y_pred = model.predict(X_test)
484
+ y_pred_proba = model.predict_proba(X_test)[:, 1]
485
+
486
+ results[name] = {
487
+ 'Accuracy': accuracy_score(y_test, y_pred),
488
+ 'Precision': precision_score(y_test, y_pred),
489
+ 'Recall': recall_score(y_test, y_pred),
490
+ 'F1-Score': f1_score(y_test, y_pred),
491
+ 'ROC-AUC': roc_auc_score(y_test, y_pred_proba)
492
+ }
493
+ trained_models[name] = model
494
+
495
+ # Select best model
496
+ best_model_name = max(results, key=lambda x: results[x]['F1-Score'])
497
+ best_model = trained_models[best_model_name]
498
+
499
+ # Store in session state
500
+ st.session_state.model = best_model
501
+ st.session_state.scaler = scaler
502
+ st.session_state.model_results = results
503
+ st.session_state.best_model_name = best_model_name
504
+ st.session_state.feature_names = X.columns.tolist()
505
+ st.session_state.model_trained = True
506
+ st.session_state.X_test = X_test
507
+ st.session_state.y_test = y_test
508
+
509
+ st.success(f"Models trained successfully! Best model: {best_model_name}")
510
+
511
+ # Display results
512
+ st.markdown("### Model Performance")
513
+ results_df = pd.DataFrame(results).T
514
+ st.dataframe(results_df.round(4))
515
+
516
+ # Feature importance
517
+ if best_model_name in ['Random Forest', 'XGBoost']:
518
+ st.markdown("### Feature Importance")
519
+ importance_df = pd.DataFrame({
520
+ 'Feature': X.columns,
521
+ 'Importance': best_model.feature_importances_
522
+ }).sort_values('Importance', ascending=False)
523
+
524
+ fig = px.bar(importance_df, x='Importance', y='Feature',
525
+ orientation='h', title="Feature Importance",
526
+ color_discrete_sequence=[COLORS['secondary']])
527
+ st.plotly_chart(fig, use_container_width=True)
528
+
529
+ # Prediction page
530
+ def prediction_page():
531
+ st.markdown('<h1 class="header-text">๐Ÿ”ฎ Customer Churn Prediction</h1>', unsafe_allow_html=True)
532
+
533
+ if not st.session_state.model_trained:
534
+ st.warning("Please train a model first.")
535
+ return
536
+
537
+ tab1, tab2 = st.tabs(["Single Prediction", "Bulk Prediction"])
538
+
539
+ with tab1:
540
+ st.markdown("### Single Customer Prediction")
541
+
542
+ col1, col2 = st.columns(2)
543
+
544
+ with col1:
545
+ age = st.number_input("Age", 18, 100, 35)
546
+ gender = st.selectbox("Gender", ["Male", "Female"])
547
+ tenure = st.number_input("Tenure (years)", 0, 10, 2)
548
+ freq_deposits = st.number_input("Frequency of Deposits/Withdrawals", 0, 14, 5)
549
+
550
+ with col2:
551
+ repayment = st.selectbox("Repayment Timeliness", ["On-time", "Late"])
552
+ account_activity = st.selectbox("Account Activity", ["Active", "Dormant"])
553
+ balance_trend = st.selectbox("Account Balance Trend", ["Decreasing", "Stable", "Increasing"])
554
+
555
+ if st.button("Predict Churn", use_container_width=True):
556
+ # Prepare input data
557
+ input_data = pd.DataFrame({
558
+ 'Age': [age / 100], # Normalized
559
+ 'Gender': [1 if gender == "Female" else 0],
560
+ 'Tenure': [tenure / 10], # Normalized
561
+ 'Frequency of Deposits/Withdrawals': [freq_deposits / 14], # Normalized
562
+ 'Repayment Timeliness': [1 if repayment == "Late" else 0],
563
+ 'Account Activity': [1 if account_activity == "Dormant" else 0],
564
+ 'Account Balance Trend': [0 if balance_trend == "Decreasing" else 1 if balance_trend == "Stable" else 2]
565
+ })
566
+
567
+ # Make prediction
568
+ prediction = st.session_state.model.predict(input_data)[0]
569
+ probability = st.session_state.model.predict_proba(input_data)[0]
570
+
571
+ # Display result
572
+ if prediction == 1:
573
+ st.markdown(f"""
574
+ <div class="prediction-result" style="background-color: {COLORS['primary']};">
575
+ <h2>โš ๏ธ HIGH CHURN RISK</h2>
576
+ <h3>Probability: {probability[1]:.1%}</h3>
577
+ </div>
578
+ """, unsafe_allow_html=True)
579
+ else:
580
+ st.markdown(f"""
581
+ <div class="prediction-result" style="background-color: {COLORS['secondary']};">
582
+ <h2>โœ… LOW CHURN RISK</h2>
583
+ <h3>Probability: {probability[0]:.1%}</h3>
584
+ </div>
585
+ """, unsafe_allow_html=True)
586
+
587
+ # Key factors
588
+ st.markdown("### Key Risk Factors")
589
+ risk_factors = []
590
+ if age < 30 or age > 70:
591
+ risk_factors.append("Age group has higher churn tendency")
592
+ if account_activity == "Dormant":
593
+ risk_factors.append("Dormant account increases churn risk")
594
+ if repayment == "Late":
595
+ risk_factors.append("Late repayments indicate financial stress")
596
+ if freq_deposits < 3:
597
+ risk_factors.append("Low transaction frequency")
598
+ if tenure < 2:
599
+ risk_factors.append("Short tenure with bank")
600
+
601
+ if risk_factors:
602
+ for factor in risk_factors:
603
+ st.write(f"โ€ข {factor}")
604
+ else:
605
+ st.write("โ€ข Customer profile shows good retention indicators")
606
+
607
+ with tab2:
608
+ st.markdown("### Bulk Prediction")
609
+
610
+ uploaded_file = st.file_uploader(
611
+ "Upload CSV file for bulk prediction",
612
+ type=['csv'],
613
+ help="Upload a CSV file with customer data"
614
+ )
615
+
616
+ if uploaded_file is not None:
617
+ try:
618
+ df = pd.read_csv(uploaded_file)
619
+ st.write(f"Loaded {len(df)} records")
620
+
621
+ if st.button("Run Bulk Prediction"):
622
+ # Process and predict
623
+ processed_df = preprocess_data(df.copy())
624
+
625
+ # Ensure all required features are present
626
+ required_features = st.session_state.feature_names
627
+ available_features = [f for f in required_features if f in processed_df.columns]
628
+
629
+ if len(available_features) == len(required_features):
630
+ X = processed_df[available_features]
631
+ X_scaled = st.session_state.scaler.transform(X)
632
+
633
+ predictions = st.session_state.model.predict(X_scaled)
634
+ probabilities = st.session_state.model.predict_proba(X_scaled)[:, 1]
635
+
636
+ # Add results to dataframe
637
+ results_df = df.copy()
638
+ results_df['Churn_Prediction'] = ['High Risk' if p == 1 else 'Low Risk' for p in predictions]
639
+ results_df['Churn_Probability'] = probabilities
640
+
641
+ st.markdown("### Prediction Results")
642
+ st.dataframe(results_df)
643
+
644
+ # Summary
645
+ high_risk_count = sum(predictions)
646
+ st.markdown(f"**Summary:** {high_risk_count} out of {len(df)} customers are at high risk of churn ({high_risk_count/len(df)*100:.1f}%)")
647
+
648
+ # Download results
649
+ csv = results_df.to_csv(index=False)
650
+ st.download_button(
651
+ "Download Results",
652
+ csv,
653
+ "churn_predictions.csv",
654
+ "text/csv"
655
+ )
656
+ else:
657
+ st.error("Missing required features in uploaded data")
658
+
659
+ except Exception as e:
660
+ st.error(f"Error processing file: {str(e)}")
661
+
662
+ # Reports page
663
+ def reports_page():
664
+ st.markdown('<h1 class="header-text">๐Ÿ“Š Model Reports</h1>', unsafe_allow_html=True)
665
+
666
+ if not st.session_state.model_trained:
667
+ st.warning("Please train a model first to view reports.")
668
+ return
669
+
670
+ # Model performance summary
671
+ st.markdown("### Model Performance Summary")
672
+ results_df = pd.DataFrame(st.session_state.model_results).T
673
+ st.dataframe(results_df.round(4))
674
+
675
+ # Best model info
676
+ st.info(f"Best Model: {st.session_state.best_model_name}")
677
+
678
+ col1, col2 = st.columns(2)
679
+
680
+ with col1:
681
+ # Feature importance
682
+ if st.session_state.best_model_name in ['Random Forest', 'XGBoost']:
683
+ st.markdown("### Feature Importance")
684
+ importance_df = pd.DataFrame({
685
+ 'Feature': st.session_state.feature_names,
686
+ 'Importance': st.session_state.model.feature_importances_
687
+ }).sort_values('Importance', ascending=False)
688
+
689
+ fig = px.bar(importance_df, x='Importance', y='Feature',
690
+ orientation='h',
691
+ color_discrete_sequence=[COLORS['secondary']])
692
+ fig.update_layout(height=400)
693
+ st.plotly_chart(fig, use_container_width=True)
694
+
695
+ with col2:
696
+ # Confusion matrix
697
+ st.markdown("### Confusion Matrix")
698
+ if hasattr(st.session_state, 'X_test') and hasattr(st.session_state, 'y_test'):
699
+ y_pred = st.session_state.model.predict(st.session_state.X_test)
700
+ cm = confusion_matrix(st.session_state.y_test, y_pred)
701
+
702
+ fig = px.imshow(cm,
703
+ text_auto=True,
704
+ aspect="auto",
705
+ color_continuous_scale='Blues',
706
+ labels=dict(x="Predicted", y="Actual"))
707
+ fig.update_layout(height=400)
708
+ st.plotly_chart(fig, use_container_width=True)
709
+
710
+ # Recommendations
711
+ st.markdown("### Business Recommendations")
712
+ recommendations = [
713
+ "Focus retention efforts on customers with short tenure and low transaction frequency",
714
+ "Implement proactive engagement for dormant accounts",
715
+ "Develop targeted programs for high-risk age groups",
716
+ "Improve digital banking adoption to increase engagement",
717
+ "Monitor and address late payment patterns early",
718
+ "Create loyalty programs for long-term customers"
719
+ ]
720
+
721
+ for i, rec in enumerate(recommendations, 1):
722
+ st.write(f"{i}. {rec}")
723
+
724
+ # Main app
725
+ def main():
726
+ st.set_page_config(
727
+ page_title="Customer Churn Prediction",
728
+ page_icon="๐Ÿฆ",
729
+ layout="wide",
730
+ initial_sidebar_state="expanded"
731
+ )
732
+
733
+ apply_custom_css()
734
+ init_session_state()
735
+
736
+ if not st.session_state.logged_in:
737
+ login_page()
738
+ return
739
+
740
+ # Sidebar navigation
741
+ st.sidebar.markdown("### Navigation")
742
+ pages = {
743
+ "๐Ÿ  Dashboard": dashboard_page,
744
+ "๐Ÿ“ Upload Data": upload_data_page,
745
+ "๐Ÿค– Train Model": model_training_page,
746
+ "๐Ÿ”ฎ Predictions": prediction_page,
747
+ "๐Ÿ“Š Reports": reports_page
748
+ }
749
+
750
+ selected_page = st.sidebar.selectbox("Choose a page", list(pages.keys()))
751
+
752
+ # Logout button
753
+ if st.sidebar.button("Logout"):
754
+ st.session_state.logged_in = False
755
+
756
+ # Display selected page
757
+ pages[selected_page]()
758
+
759
+ # Footer
760
+ st.sidebar.markdown("---")
761
+ st.sidebar.markdown("**Sunrise Microfinance Bank**")
762
+ st.sidebar.markdown("Customer Churn Prediction System")
763
 
764
+ if __name__ == "__main__":
765
+ main()