File size: 4,533 Bytes
8a9e98d
 
 
 
cd19a1e
8a9e98d
 
bdd4b0b
8a9e98d
 
 
 
 
 
 
f3854a0
496ea62
f3854a0
 
 
 
8a9e98d
 
45b2866
8a9e98d
 
 
 
 
 
 
 
 
f3854a0
8a9e98d
 
0f3b1f1
4eb404c
8a9e98d
 
0f3b1f1
8a9e98d
 
 
f3854a0
 
 
 
 
 
3c38e26
180149e
 
 
595b437
2b978e2
496ea62
3c38e26
e0b1ecb
595b437
0af6cee
 
45abae6
a24e56d
45abae6
be76a1f
595b437
 
2b978e2
0f3b1f1
 
 
2b978e2
f3854a0
 
a24e56d
f3854a0
 
 
 
 
 
45b2866
f3854a0
 
45b2866
b3268c8
a3c4a44
b3268c8
 
 
496ea62
 
 
f3854a0
7b0c295
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# Load the trained model from disk. The `with` block closes the file handle
# once unpickling is done (the previous bare open() leaked it).
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only ever load a model file that comes from a trusted source.
with open("heart_xgb.pkl", 'rb') as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.

# Create the main function for server
def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall):
    """Predict the heart-attack chance for one subject and explain it with SHAP.

    Each argument is one model feature, as delivered by the Gradio inputs
    (numbers, or choice positions for ``type="index"`` dropdowns).

    Returns:
        A ``(probabilities, figure)`` tuple:
        - a dict mapping "Low Chance" to the model's class-0 probability and
          "High Chance" to its complement;
        - the matplotlib figure holding a SHAP bar plot of this prediction
          (at most 6 features, ordered by absolute SHAP value).
    """
    features = {
        'age': age, 'sex': sex, 'cp': cp, 'trtbps': trtbps, 'chol': chol,
        'fbs': fbs, 'restecg': restecg, 'thalachh': thalachh, 'exng': exng,
        'oldpeak': oldpeak, 'slp': slp, 'caa': caa, 'thall': thall,
    }
    # One row whose columns follow the dict order above. dtype=float keeps
    # every feature numeric:
    # - an empty gr.Number (None) becomes NaN;
    # - a non-numeric value fails here with a clear ValueError.
    # The previous from_dict(...).transpose() could instead make every column
    # dtype=object, which then failed inside the model.
    new_row = pd.DataFrame([features], dtype=float)

    prob = loaded_model.predict_proba(new_row)
    low_chance = float(prob[0][0])

    shap_values = explainer(new_row)
    # show=False: the plot is not displayed; it is captured from pyplot's
    # current figure below and handed to Gradio.
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs,
                   show_data='auto', show=False)
    plt.tight_layout()
    local_plot = plt.gcf()
    # Release the figure from pyplot so figures do not pile up across calls.
    plt.close(local_plot)

    return {"Low Chance": low_chance, "High Chance": 1 - low_chance}, local_plot

# Create the UI
title = "**Heart Attack Predictor & Interpreter** 🤓🫀"
description1 = """This app takes info from subjects and predicts their heart attack likelihood. Do not use for medical diagnosis."""

description2 = """
To use the app, you can either click on one of the examples or adjust the values of the factors, and click on Process. 🤞
"""

# The markdown asterisks render as bold in the page heading, but would appear
# literally in the browser-tab title, so they are stripped there.
with gr.Blocks(title=title.replace("*", "")) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")

    with gr.Row():
        with gr.Column():
            gr.Markdown("""![Heart Attack!](file/heartattack.jpeg)""")

        with gr.Column():
            age = gr.Number(label="What's your age?")
            sex = gr.Dropdown(label="What's your sex?", choices=["Female", "Male"], type="index")
            # type="index" so the model receives the choice position (a number),
            # like every other categorical input, instead of the label string.
            cp = gr.Dropdown(label="Chest pain type", choices=["typical", "atypical", "other", "asymptomatic"], type="index")

        with gr.Column():
            trtbps = gr.Slider(label="Resting blood pressure", minimum=50, maximum=180, value=80)
            chol = gr.Number(label="What is your cholesterol in (mg/dl)?", value=100, info="cholesterol in mg/dl")
            # NOTE(review): with type="index", "yes" is sent as 0 and "no" as 1.
            # Confirm this matches the encoding the model was trained on.
            fbs = gr.Dropdown(label="Is your fasting blood sugar > 120 mg/dl?", choices=["yes", "no"], type="index")
            restecg = gr.Dropdown(label="What is your resting ECG result?", choices=["normal", "ST-T wave abnormality"], type="index", value="normal", info="resting ECG result")
            thall = gr.Dropdown(label="What is your Thalassemia condition?", choices=["NULL", "Fixed Defect", "Normal Blood Flow", "Reversible Defect"], type="index", value="NULL")

        with gr.Column():
            thalachh = gr.Number(label="What is your maximum heart rate?", value=100)
            # NOTE(review): same yes=0 / no=1 encoding as fbs; confirm against
            # the training data.
            exng = gr.Dropdown(label="exercise-induced angina", choices=["yes", "no"], type="index", value="yes")
            oldpeak = gr.Slider(label="ST depression induced by exercise relative to rest", minimum=0, maximum=10, value=0, step=1)
            slp = gr.Dropdown(label="Slope of the peak exercise ST segment", choices=["upsloping", "flat", "downsloping"], type="index", value="flat")
            caa = gr.Dropdown(label="Degree of coronary artery anomaly", choices=["0", "1", "2", "3", "4"], type="index", value="1")

    submit_btn = gr.Button("Process")

    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.Plot(label='Shap:')

    # Positional order must match main_func's parameters.
    input_components = [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall]
    output_components = [label, local_plot]

    submit_btn.click(
        main_func,
        input_components,
        output_components,
        api_name="Heart_Predictor",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Example values for the type="index" dropdowns must be one of their
    # choice strings. They are looked up in `choices` when the examples are
    # cached at launch, so raw ints or a mis-cased "male" raise an error.
    # The first example's original chest-pain value (4) had no matching
    # choice (positions run 0-3); it uses the last choice instead.
    gr.Examples(
        [
            [24, "Male", "asymptomatic", 70, 200, "yes", "normal", 80, "yes", 5, "flat", "2", "Fixed Defect"],
            [24, "Female", "asymptomatic", 80, 180, "no", "normal", 90, "no", 1, "flat", "2", "Reversible Defect"],
        ],
        input_components,
        output_components,
        main_func,
        cache_examples=True,
    )

demo.launch()