# Group_4 / app.py
# Last commit: Jessycao — "Update app.py" (cb2faa5)
import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Page theme: Gradio's Default theme with a pink primary hue, overridden so
# the page background is light blue (#ADD8E6) and component blocks are white.
_base_theme = gr.themes.Default(primary_hue="pink")
color = _base_theme.set(
    background_fill_primary="#ADD8E6",
    block_background_fill="#FFFFFF",
)
# load the model from disk
# The `with` block closes the file handle as soon as the model is read.
# NOTE: pickle.load can execute arbitrary code stored in the file, so only
# load a heart_xgb.pkl that comes from a trusted source.
with open("heart_xgb.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)
# Setup SHAP
# Explainer built over the loaded model. main_func calls it on each one-row
# input frame and bar-plots the resulting per-feature SHAP values.
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
# Lookup tables that translate the UI's choice labels into the integer codes
# fed to the model.
# NOTE(review): these codes presumably must match the encoding used when
# heart_xgb.pkl was trained — confirm before changing any of them.
gender = {
    "Female": 1,
    "Male": 0,
}
chestpain = {
    "Typical angina": 0,
    "Atypical angina": 1,
    "Non-anginal pain": 2,
    "Asymptomatic": 3,
}
bloodsuguar = {
    "True": 1,
    "False": 0,
}
rest = {
    "Probable or Definite Left Ventricular Hypertrophy by Estes' Criteria": 0,
    "Having ST - T Wave Abnormality": 1,
    "Normal Value": 2,
}
ex = {
    "Yes": 1,
    "No": 0,
}
slope = {
    "Upsloping": 0,
    "Flat": 1,
    "Downsloping": 2,
}
thl = {
    "Reversible Defect": 0,
    "Normal Blood Flow": 1,
    "Fixed Defect": 2,
}
# Create the main function for server
def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall):
    """Score one patient with the loaded model and explain the prediction.

    Choice labels are encoded with the module-level lookup tables and combined
    with the numeric inputs into a one-row DataFrame. The frame is scored with
    ``loaded_model.predict_proba`` and explained with the SHAP ``explainer``.

    Args:
        age, trtbps, chol, thalachh, oldpeak, caa: Numeric slider values,
            passed to the model unchanged.
        sex, cp, fbs, restecg, exng, slp, thall: Choice labels; each must be
            a key of its table (``gender``, ``chestpain``, ``bloodsuguar``,
            ``rest``, ``ex``, ``slope``, ``thl`` respectively).

    Returns:
        A pair ``(probabilities, figure)``:
        - ``probabilities`` maps "Low Chance" / "High Chance" to floats that
          sum to 1, in the form expected by ``gr.Label``.
        - ``figure`` is the matplotlib figure holding a SHAP bar plot of the
          six largest contributions, for ``gr.Plot``.

    Raises:
        gr.Error: If a choice input is not a known label (for example an
            unselected Radio/Dropdown), so the user sees a readable message
            instead of a bare ``KeyError``.
    """
    # Validate every choice input against its lookup table before encoding.
    choice_inputs = {
        "Sex": (sex, gender),
        "Chest Pain Type": (cp, chestpain),
        "Fasting Blood Sugar": (fbs, bloodsuguar),
        "Resting Electrocardiographic Results": (restecg, rest),
        "Exercise Induced Angina": (exng, ex),
        "Slope of the Peak Exercise": (slp, slope),
        "Type of Thalassemia": (thall, thl),
    }
    missing = [name for name, (value, table) in choice_inputs.items() if value not in table]
    if missing:
        raise gr.Error("Please select a value for: " + ", ".join(missing))

    # from_dict(orient="index") builds one column of values; transposing it
    # gives a single row whose columns are the feature names below.
    new_row = pd.DataFrame.from_dict(
        {
            'age': age,
            'sex': gender[sex],
            'cp': chestpain[cp],
            'trtbps': trtbps,
            'chol': chol,
            'fbs': bloodsuguar[fbs],
            'restecg': rest[restecg],
            'thalachh': thalachh,
            'exng': ex[exng],
            'oldpeak': oldpeak,
            'slp': slope[slp],
            'caa': caa,
            'thall': thl[thall],
        },
        orient='index',
    ).transpose()
    prob = loaded_model.predict_proba(new_row)
    # Column 0 of the single prediction row is reported as the "low" chance.
    low_chance = float(prob[0][0])

    shap_values = explainer(new_row)
    try:
        # With show=False the bar plot is left on pyplot's current figure,
        # which is picked up with plt.gcf().
        shap.plots.bar(shap_values[0], max_display=6, order=shap.Explanation.abs, show_data='auto', show=False)
        plt.tight_layout()
        local_plot = plt.gcf()
    finally:
        # Always close the current figure, even if plotting failed, so a
        # half-drawn figure is not drawn over by the next request. The Figure
        # object is still returned to gr.Plot after closing.
        plt.close()
    return {"Low Chance": low_chance, "High Chance": 1 - low_chance}, local_plot
# Create the UI
# Page heading text (Markdown bold), padded by str.center with three spaces
# on each side; the padded string is what the Blocks layout receives.
_raw_title = "🩺**Heart Attack Risk Calculator** 🏥"
title_length = len(_raw_title)
title = _raw_title.center(title_length + 6)
# Short disclaimer rendered as Markdown under the page title.
description1 = (
    "\n"
    "This app retrieves data from individuals and anticipates their probability "
    "of experiencing a heart attack. However, it should not be utilized for "
    "medical diagnosis purposes.\n"
)
# Feature guide rendered as Markdown above the input controls. Each entry is a
# list item, because consecutive plain lines would be merged into a single
# Markdown paragraph.
description2 = """
Please adjust the values of the factors below, then click **Analyze** to estimate your probability of having a heart attack.

- **Age:** Age of the patient
- **Sex:** Sex of the patient
- **Chest Pain Type:** Type of chest pain experienced by the patient
- **Resting Blood Pressure:** Blood pressure of the patient at rest, measured in millimeters of mercury (mm Hg)
- **Cholesterol:** Cholesterol level of the patient, measured in milligrams per deciliter (mg/dl)
- **Fasting Blood Sugar:** Whether the patient's fasting blood sugar is greater than 120 mg/dl
- **Resting Electrocardiographic Results:** Results of the resting electrocardiogram test
- **Maximum Heart Rate Achieved:** Maximum heart rate achieved by the patient
- **Exercise Induced Angina:** Whether the patient experienced exercise-induced angina or not
- **Short-Term Depression Induced by Exercise Relative to Rest:** Amount of ST-segment depression on the electrocardiogram induced by exercise relative to rest
- **Slope of the Peak Exercise ST Segment:** Slope of the ST segment at peak exercise
- **Number of Major Vessels:** Number of major vessels colored by fluoroscopy (0-3)
- **Thalassemia:** Blood-flow result: normal blood flow, fixed defect, or reversible defect
"""
with gr.Blocks(title=title, theme=color) as demo:
    # Header: title, disclaimer and feature guide, separated by horizontal rules.
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("---")
    gr.Markdown(description2)
    gr.Markdown("---")

    # Input controls, declared in the same order as main_func's parameters.
    age = gr.Slider(label="Age", minimum=0, maximum=100, value=40, step=1)
    sex = gr.Radio(label="Sex", choices=["Female", "Male"])
    cp = gr.Dropdown(label="Chest Pain Type", choices=["Typical angina", "Atypical angina", "Non-anginal pain", "Asymptomatic"])
    trtbps = gr.Slider(label="Resting Blood Pressure (in mm Hg)", minimum=90, maximum=200, value=90, step=1)
    chol = gr.Slider(label="Cholesterol in mg/dl", minimum=120, maximum=570, value=120, step=1)
    fbs = gr.Radio(label="Is your fasting blood sugar greater than 120mg/dl ", choices=["True", "False"])
    restecg = gr.Radio(label="Resting Electrocardiographic Results", choices=["Probable or Definite Left Ventricular Hypertrophy by Estes' Criteria", "Having ST - T Wave Abnormality", "Normal Value"])
    thalachh = gr.Slider(label="Maximum Heart Rate", minimum=70, maximum=210, value=70, step=1)
    exng = gr.Radio(label="Do You Have Exercise Induced Angina?", choices=["Yes", "No"])
    oldpeak = gr.Slider(label="Short-Term Depression Induced by Exercise Relative to Rest", minimum=0.0, maximum=6.5, value=0, step=0.1)
    slp = gr.Radio(label="Slope of the Peak Exercise", choices=["Upsloping", "Flat", "Downsloping"])
    caa = gr.Slider(label="Number of Major Vessels", minimum=0, maximum=3, value=0, step=1)
    thall = gr.Radio(label="Type of Thalassemia", choices=["Reversible Defect", "Normal Blood Flow", "Fixed Defect"])
    # Shared by the button handler and the examples so the order stays in sync.
    patient_inputs = [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall]

    submit_btn = gr.Button("Analyze")
    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.Plot(label='Shap:')
    submit_btn.click(
        fn=main_func,
        inputs=patient_inputs,
        outputs=[label, local_plot],
        api_name="Risk Predictor",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Every example value must lie within its slider's range. Resting blood
    # pressure is limited to 90-200 mm Hg, so the second example uses 90.
    gr.Examples(
        examples=[
            [24, "Male", "Atypical angina", 200, 500, "True", "Normal Value", 210, "Yes", 6.5, "Upsloping", 3, "Fixed Defect"],
            [22, "Female", "Non-anginal pain", 90, 200, "False", "Normal Value", 120, "No", 3.5, "Flat", 3, "Fixed Defect"],
            [23, "Female", "Asymptomatic", 90, 250, "True", "Normal Value", 200, "Yes", 5, "Upsloping", 3, "Normal Blood Flow"],
        ],
        inputs=patient_inputs,
        outputs=[label, local_plot],
        fn=main_func,
        # Precompute the outputs so clicking an example shows a cached result.
        cache_examples=True,
    )
# Start the Gradio server only when this file is run as a script
# (e.g. `python app.py`), so importing the module does not launch it.
if __name__ == "__main__":
    demo.launch()