alzami committed on
Commit
8dfedb7
·
verified ·
1 Parent(s): dd3075d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +358 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,360 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Materi Dr. Eng. Farrikh Alzami, M.Kom - Universitas Dian Nuswantoro
3
+ '''
4
  import streamlit as st
5
 
6
+ # Page configuration - MUST be first Streamlit command
7
+ st.set_page_config(
8
+ page_title="Income Prediction App - Materi Dr.Eng. Farrikh Alzami, M.Kom",
9
+ page_icon="๐Ÿ’ฐ",
10
+ layout="wide",
11
+ initial_sidebar_state="collapsed"
12
+ )
13
+
14
+ import pandas as pd
15
+ import numpy as np
16
+ import joblib
17
+ import plotly.express as px
18
+ import plotly.graph_objects as go
19
+ from datetime import datetime
20
+ import json
21
+
22
+ # Load model components
23
@st.cache_resource
def load_model():
    """Load the serialized model components from disk.

    Reads 'income_prediction_components.joblib' with joblib and returns
    whatever object it contains. The result is wrapped in
    ``st.cache_resource``.

    On any failure an error message is shown in the app and the script
    run is halted with ``st.stop()``; nothing is returned in that case.
    """
    model_path = 'income_prediction_components.joblib'
    try:
        return joblib.load(model_path)
    except FileNotFoundError:
        # A missing file gets its own, more specific message.
        st.error("Model file 'income_prediction_components.joblib' not found!")
        st.stop()
    except Exception as exc:
        # Any other failure while loading: report it and stop the run.
        st.error(f"Error loading model: {str(exc)}")
        st.stop()
35
+
36
def predict_income(data, model_components):
    """Predict the income class for a record using the trained model.

    Args:
        data: Either a dict mapping feature name to raw value (one record)
            or a pandas DataFrame. For a DataFrame only the first row's
            prediction is returned.
        model_components: Mapping with three keys:
            ``'model'`` - estimator exposing ``predict`` and
            ``predict_proba``;
            ``'encoding_maps'`` - per-column dicts from raw category to
            encoded value, including an ``'income'`` entry for the target;
            ``'feature_names'`` - the columns (and order) fed to the model.

    Returns:
        dict with ``'prediction'`` (encoded class as int),
        ``'prediction_label'`` (decoded via the inverse ``'income'`` map),
        ``'probability'`` (float probability of the predicted class) and
        ``'probabilities'`` (list of all class probabilities).

    Raises:
        ValueError: If a categorical value has no entry in its encoding map.
        KeyError: If a required feature column or component key is missing.
    """
    # Work on a private frame so the caller's data is never mutated.
    if isinstance(data, dict):
        df = pd.DataFrame([data])
    else:
        df = data.copy()

    model = model_components['model']
    encoding_maps = model_components['encoding_maps']
    feature_names = model_components['feature_names']

    # Encode categorical columns; the target column is never an input.
    for column in df.columns:
        if column in encoding_maps and column != 'income':
            encoded = df[column].map(encoding_maps[column])
            # Series.map yields NaN for values absent from the mapping.
            # Fail loudly instead of handing NaN to the model.
            unknown = df.loc[encoded.isna() & df[column].notna(), column]
            if not unknown.empty:
                raise ValueError(
                    f"Unknown value(s) for '{column}': "
                    f"{sorted(set(map(str, unknown)))}"
                )
            df[column] = encoded

    # Restrict to (and order by) the features the model was trained on.
    df_for_pred = df[feature_names].copy()

    prediction = model.predict(df_for_pred)[0]
    probabilities = model.predict_proba(df_for_pred)[0]

    # predict_proba columns follow model.classes_, so locate the predicted
    # class by position rather than assuming labels are 0..n-1. Fall back
    # to the label itself when the model does not expose classes_.
    classes = getattr(model, 'classes_', None)
    class_list = list(classes) if classes is not None else []
    if prediction in class_list:
        class_index = class_list.index(prediction)
    else:
        class_index = int(prediction)

    # Decode the predicted class back to its income label.
    income_map_inverse = {v: k for k, v in encoding_maps['income'].items()}
    prediction_label = income_map_inverse[prediction]

    return {
        'prediction': int(prediction),
        'prediction_label': prediction_label,
        'probability': float(probabilities[class_index]),
        'probabilities': probabilities.tolist(),
    }
71
+
72
def validate_inputs(data):
    """Check that every numeric input lies inside its accepted range.

    Args:
        data: Mapping of field name to value. The fields checked are
            'age', 'education_num', 'hours_per_week', 'capital_gain',
            'capital_loss' and 'fnlwgt'; other keys are ignored.

    Returns:
        A list of human-readable error messages in the order the fields
        are checked. The list is empty when all fields are valid. A field
        that is missing, None or not comparable to a number is reported
        as an error instead of raising.
    """
    # (field key, label used in messages, inclusive minimum, inclusive maximum)
    bounds = (
        ('age', "Age", 17, 90),
        ('education_num', "Education number", 1, 16),
        ('hours_per_week', "Hours per week", 1, 99),
        ('capital_gain', "Capital gain", 0, 99999),
        ('capital_loss', "Capital loss", 0, 4356),
        ('fnlwgt', "Final weight", 12285, 1484705),
    )

    errors = []
    for field, label, low, high in bounds:
        value = data.get(field)
        if value is None:
            errors.append(f"{label} is required")
            continue
        try:
            out_of_range = value < low or value > high
        except TypeError:
            # e.g. a string slipped through; report it rather than crash.
            errors.append(f"{label} must be a number")
            continue
        if out_of_range:
            errors.append(f"{label} should be between {low} and {high}")
    return errors
100
+
101
def export_prediction(data, result):
    """Serialize one prediction, with its inputs, to a JSON string.

    Args:
        data: The input record; stored as-is under 'input_data'.
        result: Mapping providing 'prediction_label', 'probability' and
            'prediction'.

    Returns:
        A JSON document (2-space indent) with the keys 'timestamp' (current
        local time in ISO format), 'input_data' and 'prediction'.
    """
    payload = {
        'timestamp': datetime.now().isoformat(),
        'input_data': data,
    }
    # Rename the result fields to the names used in the exported file.
    payload['prediction'] = {
        'class': result['prediction_label'],
        'confidence': result['probability'],
        'raw_prediction': result['prediction'],
    }
    serialized = json.dumps(payload, indent=2)
    return serialized
113
+
114
def reset_session_state():
    """Remove the stored value of every input field from the session state.

    Only the keys listed below are touched, and only those currently
    present; any other session entries are left untouched.
    """
    widget_keys = (
        'age', 'workclass', 'fnlwgt', 'education_num', 'marital_status',
        'occupation', 'relationship', 'race', 'sex', 'capital_gain',
        'capital_loss', 'hours_per_week', 'native_country',
    )
    # Collect first, then delete, so membership is checked before removal.
    present = [name for name in widget_keys if name in st.session_state]
    for name in present:
        del st.session_state[name]
124
+
125
+ # Load model
126
+ model_components = load_model()
127
+
128
+ # Define mappings (from the original notebook)
129
# Choices offered by the select boxes. Spellings such as 'Columbia',
# 'Trinadad&Tobago' and 'Holand-Netherlands' are kept exactly as given.
# NOTE(review): presumably these must match the keys of the model's
# encoding maps - confirm before "correcting" any of them.
workclass_options = [
    "State-gov",
    "Self-emp-not-inc",
    "Private",
    "Federal-gov",
    "Local-gov",
    "Self-emp-inc",
    "Without-pay",
    "Never-worked",
]

marital_status_options = [
    "Never-married",
    "Married-civ-spouse",
    "Divorced",
    "Married-spouse-absent",
    "Separated",
    "Married-AF-spouse",
    "Widowed",
]

occupation_options = [
    "Adm-clerical",
    "Exec-managerial",
    "Handlers-cleaners",
    "Prof-specialty",
    "Other-service",
    "Sales",
    "Craft-repair",
    "Transport-moving",
    "Farming-fishing",
    "Machine-op-inspct",
    "Tech-support",
    "Protective-serv",
    "Armed-Forces",
    "Priv-house-serv",
]

relationship_options = [
    "Not-in-family",
    "Husband",
    "Wife",
    "Own-child",
    "Unmarried",
    "Other-relative",
]

race_options = [
    "White",
    "Black",
    "Asian-Pac-Islander",
    "Amer-Indian-Eskimo",
    "Other",
]

sex_options = ["Male", "Female"]

native_country_options = [
    "United-States", "Cuba", "Jamaica", "India", "Mexico",
    "South", "Puerto-Rico", "Honduras", "England", "Canada",
    "Germany", "Iran", "Philippines", "Italy", "Poland",
    "Columbia", "Cambodia", "Thailand", "Ecuador", "Laos",
    "Taiwan", "Haiti", "Portugal", "Dominican-Republic", "El-Salvador",
    "France", "Guatemala", "China", "Japan", "Yugoslavia",
    "Peru", "Outlying-US(Guam-USVI-etc)", "Scotland", "Trinadad&Tobago", "Greece",
    "Nicaragua", "Vietnam", "Hong", "Ireland", "Hungary",
    "Holand-Netherlands",
]
153
+
154
+ # Main app
155
+ st.title("๐Ÿ’ฐ Income Prediction App - Dr. Eng. Farrikh Alzami, M.Kom")
156
+ st.markdown("Predict whether income exceeds $50K/year based on demographic data")
157
+
158
+ # Create two columns for layout
159
+ col1, col2 = st.columns([2, 1])
160
+
161
+ with col1:
162
+ st.subheader("๐Ÿ“ Input Features")
163
+
164
+ # Create form for inputs
165
+ with st.form("prediction_form"):
166
+ # Demographic Information
167
+ st.markdown("**Demographic Information**")
168
+ col_demo1, col_demo2 = st.columns(2)
169
+
170
+ with col_demo1:
171
+ age = st.number_input("Age", min_value=17, max_value=90, value=39, key="age")
172
+ sex = st.selectbox("Sex", sex_options, key="sex")
173
+ race = st.selectbox("Race", race_options, key="race")
174
+
175
+ with col_demo2:
176
+ marital_status = st.selectbox("Marital Status", marital_status_options, key="marital_status")
177
+ relationship = st.selectbox("Relationship", relationship_options, key="relationship")
178
+ native_country = st.selectbox("Native Country", native_country_options, key="native_country")
179
+
180
+ st.divider()
181
+
182
+ # Work Information
183
+ st.markdown("**Work Information**")
184
+ col_work1, col_work2 = st.columns(2)
185
+
186
+ with col_work1:
187
+ workclass = st.selectbox("Work Class", workclass_options, key="workclass")
188
+ occupation = st.selectbox("Occupation", occupation_options, key="occupation")
189
+ hours_per_week = st.number_input("Hours per Week", min_value=1, max_value=99, value=40, key="hours_per_week")
190
+
191
+ with col_work2:
192
+ education_num = st.number_input("Education Level (Years)", min_value=1, max_value=16, value=10, key="education_num")
193
+ fnlwgt = st.number_input("Final Weight", min_value=12285, max_value=1484705, value=77516, key="fnlwgt")
194
+
195
+ st.divider()
196
+
197
+ # Financial Information
198
+ st.markdown("**Financial Information**")
199
+ col_fin1, col_fin2 = st.columns(2)
200
+
201
+ with col_fin1:
202
+ capital_gain = st.number_input("Capital Gain", min_value=0, max_value=99999, value=0, key="capital_gain")
203
+
204
+ with col_fin2:
205
+ capital_loss = st.number_input("Capital Loss", min_value=0, max_value=4356, value=0, key="capital_loss")
206
+
207
+ # Buttons
208
+ col_btn1, col_btn2, col_btn3 = st.columns(3)
209
+ with col_btn1:
210
+ predict_button = st.form_submit_button("๐Ÿ”ฎ Predict", type="primary")
211
+ with col_btn2:
212
+ reset_button = st.form_submit_button("๐Ÿ”„ Reset")
213
+ with col_btn3:
214
+ export_button = st.form_submit_button("๐Ÿ“ค Export Last Result")
215
+
216
+ # Handle reset button
217
+ if reset_button:
218
+ reset_session_state()
219
+ st.rerun()
220
+
221
+ # Handle prediction
222
+ if predict_button:
223
+ # Collect input data
224
+ input_data = {
225
+ 'age': age,
226
+ 'workclass': workclass,
227
+ 'fnlwgt': fnlwgt,
228
+ 'education_num': education_num,
229
+ 'marital_status': marital_status,
230
+ 'occupation': occupation,
231
+ 'relationship': relationship,
232
+ 'race': race,
233
+ 'sex': sex,
234
+ 'capital_gain': capital_gain,
235
+ 'capital_loss': capital_loss,
236
+ 'hours_per_week': hours_per_week,
237
+ 'native_country': native_country
238
+ }
239
+
240
+ # Validate inputs
241
+ validation_errors = validate_inputs(input_data)
242
+
243
+ if validation_errors:
244
+ with col2:
245
+ st.error("โŒ Validation Errors:")
246
+ for error in validation_errors:
247
+ st.error(f"โ€ข {error}")
248
+ else:
249
+ # Make prediction
250
+ try:
251
+ result = predict_income(input_data, model_components)
252
+
253
+ # Store result in session state for export
254
+ st.session_state['last_prediction'] = {
255
+ 'input_data': input_data,
256
+ 'result': result
257
+ }
258
+
259
+ with col2:
260
+ st.subheader("๐ŸŽฏ Prediction Results")
261
+
262
+ # Display prediction
263
+ prediction_color = "green" if result['prediction_label'] == '>50K' else "orange"
264
+ st.markdown(f"**Predicted Income:** :{prediction_color}[{result['prediction_label']}]")
265
+
266
+ # Confidence level with gauge
267
+ confidence = result['probability'] * 100
268
+
269
+ fig_gauge = go.Figure(go.Indicator(
270
+ mode = "gauge+number+delta",
271
+ value = confidence,
272
+ domain = {'x': [0, 1], 'y': [0, 1]},
273
+ title = {'text': "Confidence Level (%)"},
274
+ gauge = {
275
+ 'axis': {'range': [None, 100]},
276
+ 'bar': {'color': prediction_color},
277
+ 'steps': [
278
+ {'range': [0, 50], 'color': "lightgray"},
279
+ {'range': [50, 80], 'color': "yellow"},
280
+ {'range': [80, 100], 'color': "lightgreen"}
281
+ ],
282
+ 'threshold': {
283
+ 'line': {'color': "red", 'width': 4},
284
+ 'thickness': 0.75,
285
+ 'value': 90
286
+ }
287
+ }
288
+ ))
289
+ fig_gauge.update_layout(height=300, margin=dict(l=20, r=20, t=40, b=20))
290
+ st.plotly_chart(fig_gauge, use_container_width=True)
291
+
292
+ # Probability breakdown
293
+ prob_df = pd.DataFrame({
294
+ 'Class': ['โ‰ค50K', '>50K'],
295
+ 'Probability': result['probabilities']
296
+ })
297
+
298
+ fig_bar = px.bar(
299
+ prob_df,
300
+ x='Class',
301
+ y='Probability',
302
+ title='Probability Distribution',
303
+ color='Probability',
304
+ color_continuous_scale=['orange', 'green']
305
+ )
306
+ fig_bar.update_layout(height=300, margin=dict(l=20, r=20, t=40, b=20))
307
+ st.plotly_chart(fig_bar, use_container_width=True)
308
+
309
+ except Exception as e:
310
+ with col2:
311
+ st.error(f"โŒ Prediction Error: {str(e)}")
312
+
313
+ # Feature Importance section
314
+ st.subheader("๐Ÿ“Š Feature Importance")
315
+
316
+ if 'model' in model_components:
317
+ try:
318
+ feature_names = model_components['feature_names']
319
+ feature_importance = model_components['model'].feature_importances_
320
+
321
+ importance_df = pd.DataFrame({
322
+ 'Feature': feature_names,
323
+ 'Importance': feature_importance
324
+ }).sort_values('Importance', ascending=True)
325
+
326
+ fig_importance = px.bar(
327
+ importance_df,
328
+ x='Importance',
329
+ y='Feature',
330
+ orientation='h',
331
+ title='Feature Importance in Decision Tree Model',
332
+ color='Importance',
333
+ color_continuous_scale='viridis'
334
+ )
335
+ fig_importance.update_layout(height=400, margin=dict(l=20, r=20, t=40, b=20))
336
+ st.plotly_chart(fig_importance, use_container_width=True)
337
+
338
+ except Exception as e:
339
+ st.error(f"Error displaying feature importance: {str(e)}")
340
+
341
+ # Handle export
342
+ if export_button:
343
+ if 'last_prediction' in st.session_state:
344
+ export_data = export_prediction(
345
+ st.session_state['last_prediction']['input_data'],
346
+ st.session_state['last_prediction']['result']
347
+ )
348
+
349
+ st.download_button(
350
+ label="๐Ÿ“ฅ Download Prediction Results",
351
+ data=export_data,
352
+ file_name=f"income_prediction_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json",
353
+ mime="application/json"
354
+ )
355
+ else:
356
+ st.warning("โš ๏ธ No prediction results to export. Please make a prediction first.")
357
+
358
+ # Footer
359
+ st.markdown("---")
360
+ st.markdown("*Built with Streamlit โ€ข Dr. Eng. Farrikh Alzami, M.Kom*")