fredo223 commited on
Commit
2fedeb9
·
verified ·
1 Parent(s): 7a4b932

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import pandas as pd
3
+ import shap
4
+ from shap.plots._force_matplotlib import draw_additive_plot
5
+ import gradio as gr
6
+ import numpy as np
7
+ import matplotlib.pyplot as plt
8
+ from xgboost import XGBClassifier
9
+
10
+ # load the model from disk
11
+ loaded_model = xgb.XGBClassifier()
12
+ loaded_model.load_model("heart_xgb.json")
13
+
14
+ # Setup SHAP
15
+ explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
16
+
17
+ # Create the main function for server
18
+ def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh,
19
+ exng, oldpeak, slp, caa, thall):
20
+ new_row = pd.DataFrame.from_dict({'age': age,
21
+ 'sex':sex,
22
+ 'cp':cp,
23
+ 'trtbps':trtbps,
24
+ 'chol':chol,
25
+ 'fbs':fbs,
26
+ 'restecg':restecg,
27
+ 'thalachh':thalachh,
28
+ 'exng':exng,
29
+ 'oldpeak':oldpeak,
30
+ 'slp':slp,
31
+ 'caa':caa,
32
+ 'thall':thall
33
+ }, orient = 'index').transpose()
34
+
35
+ prob = loaded_model.predict_proba(new_row)
36
+
37
+ shap_values = explainer(new_row)
38
+ # plot = shap.force_plot(shap_values[0], matplotlib=True, figsize=(30,30), show=False)
39
+ # plot = shap.plots.waterfall(shap_values[0], max_display=6, show=False)
40
+ plot = shap.plots.bar(shap_values[0], max_display=7, order=shap.Explanation.abs, show_data='auto', show=False)
41
+
42
+ plt.tight_layout()
43
+ local_plot = plt.gcf()
44
+ plt.rcParams['figure.figsize'] = 7,4
45
+ plt.close()
46
+
47
+ return {"Normal Heart Condition": float(prob[0][0]), "Critical Heart Condition": 1-float(prob[0][0])}, local_plot
48
+
49
+ # Create the UI
50
+ title = "**Heart Condition Predictor & Interpreter** 🪐"
51
+ description1 = """
52
+ This app takes inputs about patients' demographics and medical history to predict whether the patient has heart condition. There are two outputs from the app: 1- the predicted probability of normal condition or heart condition, 2- Shapley's force-plot which visualizes the extent to which each factor impacts the prediction.
53
+ """
54
+
55
+ description2 = """
56
+ To use the app, click on one of the examples, or adjust the values of the patient factors, and click on Analyze. ✨
57
+ """
58
+
59
+ with gr.Blocks(title=title) as demo:
60
+ gr.Markdown(f"## {title}")
61
+ # gr.Markdown("""![marketing](types-of-employee-turnover.jpg)""")
62
+ gr.Markdown(description1)
63
+ gr.Markdown("""---""")
64
+ gr.Markdown(description2)
65
+ gr.Markdown("""---""")
66
+ with gr.Row():
67
+ with gr.Column():
68
+ age = gr.Slider(label="age", minimum=17, maximum=74, value=24, step=1)
69
+ sex = gr.Slider(label="sex", minimum=0, maximum=1, value=1, step=1)
70
+ cp = gr.Slider(label="cp Score", minimum=1, maximum=4, value=3, step=.1)
71
+ trtbps = gr.Slider(label="trestbps Score", minimum=94, maximum=200, value=150, step=.1)
72
+ chol = gr.Slider(label="chol Score", minimum=126, maximum=564, value=400, step=.1)
73
+ fbs = gr.Slider(label="fbs Score", minimum=0, maximum=1, value=0, step=.1)
74
+ restecg = gr.Slider(label="restecg Score", minimum=0, maximum=2, value=1, step=.1)
75
+ thalachh = gr.Slider(label="thalach Score", minimum=71, maximum=202, value=90, step=.1)
76
+ exng = gr.Slider(label="exang Score", minimum=0, maximum=1, value=1, step=.1)
77
+ oldpeak = gr.Slider(label="oldpeak Score", minimum=0, maximum=6, value=4, step=.1)
78
+ slp = gr.Slider(label="slope Score", minimum=1, maximum=3, value=2, step=.1)
79
+ caa = gr.Slider(label="ca Score", minimum=0, maximum=3, value=2, step=.1)
80
+ thall = gr.Slider(label="thal Score", minimum=3, maximum=7, value=4, step=.1)
81
+
82
+ submit_btn = gr.Button("Analyze")
83
+ with gr.Column(visible=True,scale=1, min_width=600) as output_col:
84
+ label = gr.Label(label = "Predicted Label")
85
+ local_plot = gr.Plot(label = 'Shap:')
86
+
87
+ submit_btn.click(
88
+ main_func,
89
+ [age, sex,cp,trtbps,chol,fbs,restecg,thalachh,exng,oldpeak,slp,caa,thall],
90
+ [label,local_plot], api_name="Heart_Condition"
91
+ )
92
+
93
+ gr.Markdown("### Click on any of the examples below to see how it works:")
94
+ gr.Examples([[33,0,1,100,230,1,1,150,0,.9,2,1,6], [39,1,0,170,200,1,1,150,0,1.4,2,1,6]],
95
+ [age, sex,cp,trtbps,chol,fbs,restecg,thalachh,exng,oldpeak,slp,caa,thall],
96
+ [label,local_plot], main_func, cache_examples=True)
97
+
98
+ demo.launch()