File size: 4,533 Bytes
8a9e98d cd19a1e 8a9e98d bdd4b0b 8a9e98d f3854a0 496ea62 f3854a0 8a9e98d 45b2866 8a9e98d f3854a0 8a9e98d 0f3b1f1 4eb404c 8a9e98d 0f3b1f1 8a9e98d f3854a0 3c38e26 180149e 595b437 2b978e2 496ea62 3c38e26 e0b1ecb 595b437 0af6cee 45abae6 a24e56d 45abae6 be76a1f 595b437 2b978e2 0f3b1f1 2b978e2 f3854a0 a24e56d f3854a0 45b2866 f3854a0 45b2866 b3268c8 a3c4a44 b3268c8 496ea62 f3854a0 7b0c295 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 | import pickle
import pandas as pd
import shap
from shap.plots._force_matplotlib import draw_additive_plot
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
# Load the trained model from disk. The context manager closes the file handle
# as soon as unpickling finishes instead of leaving it open for the life of
# the process.
# NOTE: unpickling executes code embedded in the file; only load a model file
# that ships with this app and is trusted.
with open("heart_xgb.pkl", "rb") as model_file:
    loaded_model = pickle.load(model_file)
# Setup SHAP
explainer = shap.Explainer(loaded_model) # PLEASE DO NOT CHANGE THIS.
# Create the main function for server
def main_func(age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall):
    """Score one subject and build a SHAP explanation for that prediction.

    The thirteen arguments are the feature values for a single subject. They
    are passed to the model as a one-row DataFrame whose column names are
    identical to the parameter names, in the same order.

    Returns:
        A 2-tuple ``(label_scores, figure)``. ``label_scores`` maps
        "Low Chance" to the first class probability reported by the model and
        "High Chance" to its complement. ``figure`` is the matplotlib figure
        holding the SHAP bar plot for this row.
    """
    features = {
        'age': age, 'sex': sex,
        'cp': cp, 'trtbps': trtbps, 'chol': chol,
        'fbs': fbs, 'restecg': restecg, 'thalachh': thalachh, 'exng': exng,
        'oldpeak': oldpeak, 'slp': slp, 'caa': caa, 'thall': thall,
    }
    # orient='index' yields one row per feature; transposing turns that into a
    # single row with one column per feature.
    new_row = pd.DataFrame.from_dict(features, orient='index').transpose()

    prob = loaded_model.predict_proba(new_row)
    shap_values = explainer(new_row)

    # show=False leaves the bar plot on the current pyplot figure so it can be
    # captured below instead of being displayed.
    shap.plots.bar(
        shap_values[0],
        max_display=6,
        order=shap.Explanation.abs,
        show_data='auto',
        show=False,
    )
    plt.tight_layout()
    local_plot = plt.gcf()
    # Deregister exactly the captured figure from pyplot so figures do not
    # accumulate across requests; the Figure object itself stays usable.
    plt.close(local_plot)

    low_chance = float(prob[0][0])
    return {"Low Chance": low_chance, "High Chance": 1 - low_chance}, local_plot
# Create the UI
title = "**Heart Attack Predictor & Interpreter** 🤓🫀"
description1 = """This app takes info from subjects and predicts their heart attack likelihood. Do not use for medical diagnosis."""
# The instructions name the button by the label it actually carries ("Process").
description2 = """
To use the app, you can either click on one of the examples or adjust the values of the factors, and click on Process. 🤞
"""

# Every categorical input uses type="index", so main_func receives the
# position of the selected choice (an integer) rather than its label. The
# order of each `choices` list is therefore part of the model's input encoding.
# NOTE(review): the "no"/"yes" ordering for fbs and exng assumes the model was
# trained with 0 = no and 1 = yes -- confirm against the training data.
with gr.Blocks(title=title) as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description1)
    gr.Markdown("""---""")
    gr.Markdown(description2)
    gr.Markdown("""---""")

    with gr.Row():
        # Leading column holds only an empty Markdown component.
        with gr.Column():
            gr.Markdown("""""")
        with gr.Column():
            age = gr.Number(label="What's your age?")
            sex = gr.Dropdown(label="What's your sex?", choices=["Female", "Male"], type="index")
            cp = gr.Dropdown(
                label="Chest pain type",
                choices=["typical", "atypical", "other", "asymptomatic"],
                type="index",
            )
        with gr.Column():
            trtbps = gr.Slider(minimum=50, maximum=180, value=80, label="Resting blood pressure")
            chol = gr.Number(label="What is your cholesterol in (mg/dl)?", value=100, info="cholesterol in mg/dl")
            fbs = gr.Dropdown(
                label="Is your fasting blood sugar > 120 mg/dl?",
                choices=["no", "yes"],
                type="index",
            )
            restecg = gr.Dropdown(
                label="What is your resting ECG result?",
                choices=["normal", "ST-T wave abnormality"],
                type="index",
                value="normal",
                info="resting ECG result",
            )
            thall = gr.Dropdown(
                label="What is your Thalassemia condition?",
                choices=["NULL", "Fixed Defect", "Normal Blood Flow", "Reversible Defect"],
                type="index",
                value="NULL",
            )
        with gr.Column():
            thalachh = gr.Number(label="What is your maximum heart rate?", value=100)
            exng = gr.Dropdown(
                label="exercise-induced angina",
                choices=["no", "yes"],
                type="index",
                value="yes",
            )
            oldpeak = gr.Slider(
                label="ST depression induced by exercise relative to rest",
                minimum=0,
                maximum=10,
                value=0,
                step=1,
            )
            slp = gr.Dropdown(
                label="Slope of the peak exercise ST segment",
                choices=["upsloping", "flat", "downsloping"],
                type="index",
                value="flat",
            )
            caa = gr.Dropdown(
                label="Degree of coronary artery anomaly",
                choices=["0", "1", "2", "3", "4"],
                type="index",
                value="1",
            )

    submit_btn = gr.Button("Process")

    with gr.Column(visible=True) as output_col:
        label = gr.Label(label="Predicted Label")
        local_plot = gr.Plot(label='Shap:')

    # Same order as main_func's parameters.
    feature_inputs = [age, sex, cp, trtbps, chol, fbs, restecg, thalachh, exng, oldpeak, slp, caa, thall]

    submit_btn.click(
        main_func,
        feature_inputs,
        [label, local_plot],
        api_name="Heart_Predictor",
    )

    gr.Markdown("### Click on any of the examples below to see how it works:")
    # Dropdown cells must be spelled exactly like one of that dropdown's
    # choices; otherwise the index lookup fails when the examples are cached.
    gr.Examples(
        examples=[
            [24, "Male", "asymptomatic", 70, 200, "yes", "normal", 80, "yes", 5, "flat", "2", "Fixed Defect"],
            [24, "Female", "other", 80, 180, "no", "normal", 90, "no", 1, "flat", "2", "Reversible Defect"],
        ],
        inputs=feature_inputs,
        outputs=[label, local_plot],
        fn=main_func,
        cache_examples=True,
    )

demo.launch()
|