AlbertMaaaa committed on
Commit
c0d3dcc
·
verified ·
1 Parent(s): 31292dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -49
app.py CHANGED
@@ -1,59 +1,111 @@
1
  import pickle
2
  import pandas as pd
3
- import gradio as gr
4
  import shap
 
 
 
5
  import matplotlib.pyplot as plt
6
 
7
- # Load model and SHAP explainer
8
- model = pickle.load(open("default_xgb.pkl", "rb"))
9
- explainer = shap.Explainer(model)
10
-
11
- # Predict function using threshold = 0.4 and SHAP bar chart
12
- def predict_default(education, limit_bal, pay_0, pay_2, pay_3, pay_4, pay_5, pay_6):
13
- input_data = pd.DataFrame([{
14
- 'EDUCATION': education,
15
- 'LIMIT_BAL': limit_bal,
16
- 'PAY_0': pay_0,
17
- 'PAY_2': pay_2,
18
- 'PAY_3': pay_3,
19
- 'PAY_4': pay_4,
20
- 'PAY_5': pay_5,
21
- 'PAY_6': pay_6
22
- }])
23
-
24
- # Predict probability and apply threshold
25
- prob = model.predict_proba(input_data)[0, 1]
26
- prediction = "Default" if prob >= 0.4 else "No Default"
27
-
28
- # SHAP explanation
29
- shap_values = explainer(input_data)
 
 
 
 
 
 
 
 
 
30
  plt.figure()
31
- shap.plots.bar(shap_values[0], max_display=7, show=False)
32
- plt.tight_layout()
33
- fig = plt.gcf()
34
  plt.close()
35
 
36
- return f"Prediction: {prediction} (Probability: {prob:.2f})", fig
37
-
38
- # Gradio Interface
39
- demo = gr.Interface(
40
- fn=predict_default,
41
- inputs=[
42
- gr.Dropdown([1, 2, 3, 4], label="Education Level"),
43
- gr.Number(label="Credit Limit"),
44
- gr.Number(label="Repayment Status Sept (PAY_0)"),
45
- gr.Number(label="Repayment Status Aug (PAY_2)"),
46
- gr.Number(label="Repayment Status July (PAY_3)"),
47
- gr.Number(label="Repayment Status June (PAY_4)"),
48
- gr.Number(label="Repayment Status May (PAY_5)"),
49
- gr.Number(label="Repayment Status April (PAY_6)")
50
- ],
51
- outputs=[
52
- gr.Text(label="Prediction Result"),
53
- gr.Plot(label="SHAP Explanation")
54
- ],
55
- title="📊 Credit Default Predictor with SHAP",
56
- description="An XGBoost model using Optuna hyperparameters and a 0.4 threshold to predict credit default risk. SHAP plot shows which features drive the prediction."
57
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  demo.launch()
 
1
  import pickle
2
  import pandas as pd
 
3
  import shap
4
+ from shap.plots._force_matplotlib import draw_additive_plot
5
+ import gradio as gr
6
+ import numpy as np
7
  import matplotlib.pyplot as plt
8
 
9
+ # Load model
10
+ loaded_model = pickle.load(open("default_xgb.pkl", 'rb'))
11
+
12
+ # SHAP Explainer
13
+ explainer = shap.Explainer(loaded_model) # DO NOT CHANGE
14
+
15
+ # Define main function
16
+ def main_func(LIMIT_BAL, EDUCATION,
17
+ PAY_0, PAY_2, PAY_3, PAY_4, PAY_5, PAY_6,
18
+ b1, b2, b3, b4, b5, b6):
19
+
20
+ new_row = pd.DataFrame.from_dict({
21
+ 'LIMIT_BAL': [LIMIT_BAL],
22
+ 'EDUCATION': [EDUCATION],
23
+ 'PAY_0': [PAY_0],
24
+ 'PAY_2': [PAY_2],
25
+ 'PAY_3': [PAY_3],
26
+ 'PAY_4': [PAY_4],
27
+ 'PAY_5': [PAY_5],
28
+ 'PAY_6': [PAY_6],
29
+ 'b1': [b1],
30
+ 'b2': [b2],
31
+ 'b3': [b3],
32
+ 'b4': [b4],
33
+ 'b5': [b5],
34
+ 'b6': [b6]
35
+ })
36
+
37
+ prob = loaded_model.predict_proba(new_row)[0][1]
38
+ prediction = int(prob > 0.4)
39
+
40
+ shap_values = explainer(new_row)
41
  plt.figure()
42
+ shap.plots.bar(shap_values[0], max_display=10, show=False)
43
+ local_plot = plt.gcf()
 
44
  plt.close()
45
 
46
+ return {"No Default": 1 - float(prob), "Default": float(prob)}, local_plot
47
+
48
+ # App UI metadata
49
+ title = "**Default Risk Predictor & Interpreter** 📊"
50
+ description1 = """
51
+ This app uses financial data such as credit limits, repayment behavior, and bill balances to estimate the probability that a person may **default on their next credit payment**.
52
+ It also includes an interactive SHAP plot that explains which features most influenced the prediction.
53
+ """
54
+ description2 = """
55
+ To use the app, either adjust the values or click one of the examples. Then click on **Analyze** to get predictions and interpretations. 🔍
56
+ """
57
+
58
+ # Launch Gradio app
59
+ with gr.Blocks(title=title) as demo:
60
+ gr.Markdown(f"## {title}")
61
+ gr.Markdown(description1)
62
+ gr.Markdown("---")
63
+ gr.Markdown(description2)
64
+ gr.Markdown("---")
65
+
66
+ with gr.Row():
67
+ with gr.Column():
68
+ LIMIT_BAL = gr.Slider(label="Credit Limit (LIMIT_BAL)", minimum=10000, maximum=1000000, value=150000, step=5000)
69
+ EDUCATION = gr.Slider(label="Education Level (EDUCATION)", minimum=0, maximum=6, value=2, step=1)
70
+ PAY_0 = gr.Slider(label="PAY_0 (Most recent payment status)", minimum=-2, maximum=8, value=0, step=1)
71
+ PAY_2 = gr.Slider(label="PAY_2", minimum=-2, maximum=8, value=0, step=1)
72
+ PAY_3 = gr.Slider(label="PAY_3", minimum=-2, maximum=8, value=0, step=1)
73
+ PAY_4 = gr.Slider(label="PAY_4", minimum=-2, maximum=8, value=0, step=1)
74
+ PAY_5 = gr.Slider(label="PAY_5", minimum=-2, maximum=8, value=0, step=1)
75
+ PAY_6 = gr.Slider(label="PAY_6", minimum=-2, maximum=8, value=0, step=1)
76
+ b1 = gr.Slider(label="Balance 1 (b1)", minimum=-50000, maximum=50000, value=5000, step=500)
77
+ b2 = gr.Slider(label="Balance 2 (b2)", minimum=-50000, maximum=50000, value=4500, step=500)
78
+ b3 = gr.Slider(label="Balance 3 (b3)", minimum=-50000, maximum=50000, value=4700, step=500)
79
+ b4 = gr.Slider(label="Balance 4 (b4)", minimum=-50000, maximum=50000, value=4300, step=500)
80
+ b5 = gr.Slider(label="Balance 5 (b5)", minimum=-50000, maximum=50000, value=4000, step=500)
81
+ b6 = gr.Slider(label="Balance 6 (b6)", minimum=-50000, maximum=50000, value=3900, step=500)
82
+
83
+ submit_btn = gr.Button("Analyze")
84
+
85
+ with gr.Column(visible=True, scale=1, min_width=600) as output_col:
86
+ label = gr.Label(label="Predicted Default Risk")
87
+ local_plot = gr.Plot(label="SHAP Explanation Plot")
88
+
89
+ submit_btn.click(
90
+ main_func,
91
+ [LIMIT_BAL, EDUCATION,
92
+ PAY_0, PAY_2, PAY_3, PAY_4, PAY_5, PAY_6,
93
+ b1, b2, b3, b4, b5, b6],
94
+ [label, local_plot],
95
+ api_name="Default_Risk"
96
+ )
97
+
98
+ gr.Markdown("### Example inputs to try:")
99
+ gr.Examples(
100
+ examples=[
101
+ [200000, 2, 0, 0, 0, 0, 0, 0, 5200, 5000, 4800, 4600, 4400, 4200],
102
+ [80000, 3, 2, 0, 2, 3, 2, 2, -3000, -2500, -2000, -1800, -1500, -1200]
103
+ ],
104
+ inputs=[LIMIT_BAL, EDUCATION, PAY_0, PAY_2, PAY_3, PAY_4, PAY_5, PAY_6,
105
+ b1, b2, b3, b4, b5, b6],
106
+ outputs=[label, local_plot],
107
+ fn=main_func,
108
+ cache_examples=True
109
+ )
110
 
111
  demo.launch()