File size: 6,259 Bytes
b6e76cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt

# App colour scheme: Gradio's default theme with a pink primary hue, a
# light-blue (#ADD8E6) page background and white (#FFFFFF) component blocks.
_theme_overrides = {
    "background_fill_primary": "#ADD8E6",
    "block_background_fill": "#FFFFFF",
}
color = gr.themes.Default(primary_hue="pink").set(**_theme_overrides)

# Load the trained classifier from disk. The context manager guarantees the
# file handle is closed as soon as unpickling finishes.
# SECURITY: pickle.load can execute arbitrary code embedded in the file -- only
# ever point this at a trusted model file shipped with the app.
with open("heart_xgb.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)

# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.

# Lookup tables translating the human-readable choices shown in the UI into
# the integer codes the model receives (used by main_func).
# NOTE(review): these codes must match the encoding used when heart_xgb.pkl was
# trained. Some of them differ from the UCI heart-disease codebook (there,
# sex 1 = male and restecg 0 = normal / 2 = left ventricular hypertrophy) --
# confirm against the training pipeline.

# Binary yes/no style answers.
gender = {"Female": 1, "Male": 0}
bloodsuguar = {"True": 1, "False": 0}  # (sic) name kept: main_func refers to it
ex = {"Yes": 1, "No": 0}

# Ordered categories: each choice maps to its position in the list (0, 1, 2, ...).
chestpain = {
    choice: code
    for code, choice in enumerate(
        ["Typical angina", "Atypical angina", "Non-anginal pain", "Asymptomatic"]
    )
}
rest = {
    choice: code
    for code, choice in enumerate(
        [
            "Probable or Definite Left Ventricular Hypertrophy by Estes' Criteria",
            "Having ST - T Wave Abnormality",
            "Normal Value",
        ]
    )
}
slope = {
    choice: code
    for code, choice in enumerate(["Upsloping", "Flat", "Downsloping"])
}
thl = {
    choice: code
    for code, choice in enumerate(
        ["Reversible Defect", "Normal Blood Flow", "Fixed Defect"]
    )
}

def _lookup(mapping, choice, field):
    """Return the integer code for ``choice`` in ``mapping``.

    Categorical inputs can arrive as ``None`` (the Radio/Dropdown widgets in the
    UI have no default value) or as an unexpected string. Instead of letting
    that escape as a bare ``KeyError``, raise ``gr.Error`` so Gradio shows the
    user a readable message naming the missing field.
    """
    try:
        return mapping[choice]
    except KeyError:
        raise gr.Error(f"Please select a value for '{field}'.") from None


# Create the main function for server
def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall):
    """Predict heart-attack risk for one person and explain the prediction.

    Parameters are the raw widget values, in the order the UI wires them up.
    Numeric values are passed through unchanged. Categorical choices are
    translated to integer codes via the module-level lookup tables.

    Returns:
        A ``(probabilities, figure)`` pair:
        - a dict with "Low Chance" / "High Chance" probabilities (for ``gr.Label``);
        - the matplotlib Figure holding a SHAP bar plot of the top 6 feature
          contributions (for ``gr.Plot``).

    Raises:
        gr.Error: if a categorical input was left unselected or is unknown.
    """
    # A single-row frame. Column names and order are exactly those used before;
    # presumably they are the feature names the model was fitted on -- keep
    # them in sync with training.
    new_row = pd.DataFrame([{
        "age": age,
        "sex": _lookup(gender, sex, "Sex"),
        "cp": _lookup(chestpain, cp, "Chest Pain Type"),
        "trtbps": trtbps,
        "chol": chol,
        "fbs": _lookup(bloodsuguar, fbs, "Fasting Blood Sugar"),
        "restecg": _lookup(rest, restecg, "Resting Electrocardiographic Results"),
        "thalachh": thalachh,
        "exng": _lookup(ex, exng, "Exercise Induced Angina"),
        "oldpeak": oldpeak,
        "slp": _lookup(slope, slp, "Slope of the Peak Exercise"),
        "caa": caa,
        "thall": _lookup(thl, thall, "Type of Thalassemia"),
    }])

    prob = loaded_model.predict_proba(new_row)
    # Probability of the model's first class is reported as "Low Chance";
    # its complement is reported as "High Chance".
    low_chance = float(prob[0][0])

    # Explain this single prediction. With show=False the bar plot is left on
    # pyplot's current figure, which is fetched with plt.gcf() below.
    shap_values = explainer(new_row)
    shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)

    plt.tight_layout()
    local_plot = plt.gcf()
    # Remove the figure from pyplot's global registry so figures do not pile
    # up across requests; the Figure object itself is still returned.
    plt.close(local_plot)

    return {"Low Chance": low_chance, "High Chance": 1 - low_chance}, local_plot

# Create the UI
# Plain-text title. It is used both as the Markdown heading and as the browser
# tab title (gr.Blocks(title=...)), where Markdown markup such as ** would
# show up literally.
title = "🩺 Heart Attack Risk Calculator 🏥"

description1 = """
This app takes information about an individual and estimates their probability of experiencing a heart attack. It is not a diagnostic tool and must not be used for medical diagnosis.
"""

# Each field is its own Markdown list item. Plain consecutive lines would be
# merged into a single run-on paragraph by the Markdown renderer.
description2 = """
Please adjust the values of the factors below, then click **Analyze** to estimate your probability of having a heart attack.

- **Age:** Age of the patient
- **Sex:** Sex of the patient
- **Chest Pain Type:** Type of chest pain experienced by the patient
- **Resting Blood Pressure:** Blood pressure of the patient at rest, measured in millimeters of mercury (mm Hg)
- **Cholesterol:** Cholesterol level of the patient, measured in milligrams per deciliter (mg/dl)
- **Fasting Blood Sugar:** Whether the patient's fasting blood sugar level is greater than 120 milligrams per deciliter (mg/dl)
- **Resting Electrocardiographic Results:** Results of the resting electrocardiogram (ECG) test
- **Maximum Heart Rate Achieved:** Maximum heart rate achieved by the patient
- **Exercise Induced Angina:** Whether the patient experienced exercise-induced angina
- **ST Depression Induced by Exercise Relative to Rest:** Amount of depression of the ST segment of the ECG induced by exercise, relative to rest
- **Slope of the Peak Exercise ST Segment:** Slope of the ST segment at peak exercise
- **Number of Major Vessels:** Number of major vessels colored by fluoroscopy (0-3)
- **Thalassemia:** Thalassemia category of the patient (normal blood flow, fixed defect, or reversible defect)

"""

with gr.Blocks(title=title, theme=color) as demo:
    # Header: title, disclaimer and a glossary of the input fields.
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")

    # Input widgets. The order of `input_components` must match main_func's
    # parameter order, because the handler receives the values positionally.
    age = gr.Slider(label="Age", minimum=0, maximum=100, value=40, step=1)
    sex = gr.Radio(label="Sex", choices=["Female", "Male"])
    cp = gr.Dropdown(label="Chest Pain Type", choices=["Typical angina", "Atypical angina", "Non-anginal pain", "Asymptomatic"])
    trtbps = gr.Slider(label="Resting Blood Pressure (in mm Hg)", minimum=90, maximum=200, value=90, step=1)
    chol = gr.Slider(label="Cholesterol in mg/dl", minimum=120, maximum=570, value=120, step=1)
    fbs = gr.Radio(label="Is your fasting blood sugar greater than 120mg/dl ", choices=["True", "False"])
    restecg = gr.Radio(
        label="Resting Electrocardiographic Results",
        choices=["Probable or Definite Left Ventricular Hypertrophy by Estes' Criteria", "Having ST - T Wave Abnormality", "Normal Value"],
    )
    thalachh = gr.Slider(label="Maximum Heart Rate", minimum=70, maximum=210, value=70, step=1)
    exng = gr.Radio(label="Do You Have Exercise Induced Angina?", choices=["Yes", "No"])
    oldpeak = gr.Slider(label="ST Depression Induced by Exercise Relative to Rest", minimum=0.0, maximum=6.5, value=0, step=0.1)
    slp = gr.Radio(label="Slope of the Peak Exercise", choices=["Upsloping", "Flat", "Downsloping"])
    caa = gr.Slider(label="Number of Major Vessels", minimum=0, maximum=3, value=0, step=1)
    thall = gr.Radio(label="Type of Thalassemia", choices=["Reversible Defect", "Normal Blood Flow", "Fixed Defect"])
    input_components = [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall]

    submit_btn = gr.Button("Analyze")

    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.Plot(label='Shap:')
    output_components = [label, local_plot]

    submit_btn.click(
        main_func,
        input_components,
        output_components,
        api_name="Risk Predictor",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Every example value must lie within its widget's range so that clicking
    # an example reproduces a state the user could also set by hand.
    gr.Examples(
        [
            [24, "Male", "Atypical angina", 200, 500, "True", "Normal Value", 210, "Yes", 6.5, "Upsloping", 3, "Fixed Defect"],
            [22, "Female", "Non-anginal pain", 90, 200, "False", "Normal Value", 120, "No", 3.5, "Flat", 3, "Fixed Defect"],
            [23, "Female", "Asymptomatic", 90, 250, "True", "Normal Value", 200, "Yes", 5, "Upsloping", 3, "Normal Blood Flow"],
        ],
        input_components,
        output_components,
        main_func,
        cache_examples=True,
    )

# Only start the server when run as a script, not when the module is imported
# (e.g. by tooling that just needs the `demo` object).
if __name__ == "__main__":
    demo.launch()