rhk8hy commited on
Commit
09b692d
·
1 Parent(s): 849966d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -38
app.py CHANGED
@@ -5,43 +5,35 @@ from shap.plots._force_matplotlib import draw_additive_plot
5
  import gradio as gr
6
  import numpy as np
7
  import matplotlib.pyplot as plt
8
-
9
  # load the model from disk
10
  loaded_model = pickle.load(open("heart_xgb.pkl", 'rb'))
11
-
12
  # Setup SHAP
13
  explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
14
-
15
  # Create the main function for server
16
- def main_func(age, sex,cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp,caa,thall):
17
  new_row = pd.DataFrame.from_dict({'age':age,'sex':sex,
18
- 'cp':cp,'trtbps':trtbps,'chol':chol,
19
- 'fbs':fbs,'restecg': restecg, 'thalachh':thalachh, 'exng':exng, 'oldpeak':oldpeak, 'slp':slp, 'caa':caa,'thall':thall},
20
- orient = 'index').transpose()
21
-
22
  prob = loaded_model.predict_proba(new_row)
23
-
24
  shap_values = explainer(new_row)
25
  # plot = shap.force_plot(shap_values[0], matplotlib=True, figsize=(30,30), show=False)
26
  # plot = shap.plots.waterfall(shap_values[0], max_display=6, show=False)
27
  plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
28
-
29
  plt.tight_layout()
30
  local_plot = plt.gcf()
31
  plt.close()
32
-
33
- return {"low chance": float(prob[0][0]), "high chance": 1-float(prob[0][0])}, local_plot
34
-
35
  # Create the UI
36
  title = "**Heart Attack Predictor & Interpreter** 🪐"
37
  description1 = """
38
- This app takes info from subjects and predicts their heart attack likelihood. Do not use for medical purposes .✨
39
  """
40
-
41
  description2 = """
42
- 🤞
43
  """
44
-
45
  with gr.Blocks(title=title) as demo:
46
  gr.Markdown(f"## {title}")
47
  # gr.Markdown("""![marketing](file/marketing.jpg)""")
@@ -49,35 +41,29 @@ with gr.Blocks(title=title) as demo:
49
  gr.Markdown("""---""")
50
  gr.Markdown(description2)
51
  gr.Markdown("""---""")
52
- age = gr.Slider(label="age score", minimum=15, maximum=90, value=40, step=1)
53
  sex = gr.Slider(label="sex score", minimum=0, maximum=1, value=1, step=1)
54
  cp = gr.Slider(label="cp score", minimum=1, maximum=5, value=4, step=1)
55
- trtbps = gr.Slider(label="trtbps score", minimum=1, maximum=5, value=4, step=1)
56
- chol = gr.Slider(label="chol score", minimum=1, maximum=5, value=4, step=1)
57
- fbs = gr.Slider(label="fbs score", minimum=1, maximum=5, value=4, step=1)
58
- restecg = gr.Slider(label="restecg score", minimum=1, maximum=5, value=4, step=1)
59
- thalachh = gr.Slider(label="thalachh score", minimum=1, maximum=5, value=4, step=1)
60
- exng = gr.Slider(label="exng score", minimum=1, maximum=5, value=4, step=1)
61
- oldpeak = gr.Slider(label="oldpeak score", minimum=1, maximum=5, value=4, step=1)
62
- slp = gr.Slider(label="slp score", minimum=1, maximum=5, value=4, step=1)
63
- caa = gr.Slider(label="caa score", minimum=1, maximum=5, value=4, step=1)
64
- thall = gr.Slider(label="thall score", minimum=1, maximum=5, value=4, step=1)
65
-
66
-
67
-
68
  submit_btn = gr.Button("Analyze")
69
-
70
  with gr.Column(visible=True) as output_col:
71
  label = gr.Label(label = "Predicted Label")
72
  local_plot = gr.Plot(label = 'Shap:')
73
-
74
  submit_btn.click(
75
  main_func,
76
- [age, sex,cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp,caa,thall],
77
- [label,local_plot], api_name="Heart_Predictor"
78
  )
79
-
80
  gr.Markdown("### Click on any of the examples below to see how it works:")
81
- gr.Examples([[24,0,4,4,5,5,4,4,5,5,1,2,3], [20,0,4,4,5,5,4,3,5,5,1,2,3]], [age, sex,cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp,caa,thall], [label,local_plot], main_func, cache_examples=True)
82
-
83
  demo.launch()
 
5
  import gradio as gr
6
  import numpy as np
7
  import matplotlib.pyplot as plt
 
8
  # load the model from disk
9
  loaded_model = pickle.load(open("heart_xgb.pkl", 'rb'))
 
10
  # Setup SHAP
11
  explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
 
12
  # Create the main function for server
13
+ def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall):
14
  new_row = pd.DataFrame.from_dict({'age':age,'sex':sex,
15
+ 'cp':cp,'trtbps':trtbps,'chol':chol, 'fbs':fbs, 'restecg':restecg,
16
+ 'thalachh':thalachh, 'exng':exng, 'oldpeak':oldpeak, 'slp':slp, 'caa':caa, 'thall':thall}, orient = 'index').transpose()
17
+
 
18
  prob = loaded_model.predict_proba(new_row)
19
+
20
  shap_values = explainer(new_row)
21
  # plot = shap.force_plot(shap_values[0], matplotlib=True, figsize=(30,30), show=False)
22
  # plot = shap.plots.waterfall(shap_values[0], max_display=6, show=False)
23
  plot = shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
 
24
  plt.tight_layout()
25
  local_plot = plt.gcf()
26
  plt.close()
27
+
28
+ return {"Leave": float(prob[0][0]), "Stay": 1-float(prob[0][0])}, local_plot
 
29
  # Create the UI
30
  title = "**Heart Attack Predictor & Interpreter** 🪐"
31
  description1 = """
32
+ This app takes info from subjects and predicts their heart attack likelihood. Do not use for medical diagnosis.
33
  """
 
34
  description2 = """
35
+ To use the app, click on one of the examples, or adjust the values of the factors, and click on Analyze. 🤞
36
  """
 
37
  with gr.Blocks(title=title) as demo:
38
  gr.Markdown(f"## {title}")
39
  # gr.Markdown("""![marketing](file/marketing.jpg)""")
 
41
  gr.Markdown("""---""")
42
  gr.Markdown(description2)
43
  gr.Markdown("""---""")
44
+ age = gr.Slider(label="age score", minimum=15, maximum=90, value=40, step=5)
45
  sex = gr.Slider(label="sex score", minimum=0, maximum=1, value=1, step=1)
46
  cp = gr.Slider(label="cp score", minimum=1, maximum=5, value=4, step=1)
47
+ trtbps = gr.Slider(label="trtbps Score", minimum=1, maximum=5, value=4, step=1)
48
+ chol = gr.Slider(label="chol Score", minimum=1, maximum=5, value=4, step=1)
49
+ fbs = gr.Slider(label="fbs Score", minimum=1, maximum=5, value=4, step=1)
50
+ restecg = gr.Slider(label="restecg Score", minimum=1, maximum=5, value=4, step=1)
51
+ thalachh = gr.Slider(label="thalachh Score", minimum=1, maximum=5, value=4, step=1)
52
+ exng = gr.Slider(label="exng Score", minimum=1, maximum=5, value=4, step=1)
53
+ oldpeak = gr.Slider(label="oldpeak Score", minimum=1, maximum=5, value=4, step=1)
54
+ slp = gr.Slider(label="slp Score", minimum=1, maximum=5, value=4, step=1)
55
+ caa = gr.Slider(label="caa Score", minimum=1, maximum=5, value=4, step=1)
56
+ thall = gr.Slider(label="thall Score", minimum=1, maximum=5, value=4, step=1)
 
 
 
57
  submit_btn = gr.Button("Analyze")
 
58
  with gr.Column(visible=True) as output_col:
59
  label = gr.Label(label = "Predicted Label")
60
  local_plot = gr.Plot(label = 'Shap:')
 
61
  submit_btn.click(
62
  main_func,
63
+ [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall],
64
+ [label,local_plot], api_name="Employee_Turnover"
65
  )
66
+
67
  gr.Markdown("### Click on any of the examples below to see how it works:")
68
+ gr.Examples([[24,0,4,4,5,4,4,5,5,1,2,3,4], [20,0,3,4,5,4,4,5,5,1,2,3,3]], [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall], [label,local_plot], main_func, cache_examples=True)
 
69
  demo.launch()