# heart_predict / app.py
# Author: Roberta2024
# Last commit: "Update app.py" (26b9374, verified)
import gradio as gr
import pandas as pd
import joblib, os
# Resolve bundled assets relative to this file rather than the current working
# directory, so the app starts the same way wherever it is launched from.
script_dir = os.path.dirname(os.path.abspath(__file__))
_toolkit_dir = os.path.join(script_dir, 'toolkit')

pipeline_path = os.path.join(_toolkit_dir, 'pipeline.joblib')
model_path = os.path.join(_toolkit_dir, 'Random Forest Classifier.joblib')
heart_data_path = os.path.join(script_dir, 'heart.csv')

# Deserialize the preprocessing pipeline first, then the classifier.
# joblib.load unpickles its input, so these files must come from a trusted source.
pipeline, model = map(joblib.load, (pipeline_path, model_path))

# NOTE(review): heart_df appears unused by the code in this file; reading it
# only adds startup cost and a failure point if heart.csv is missing.
heart_df = pd.read_csv(heart_data_path)
def predict(age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal):
    """Return heart-disease outcome probabilities for a single patient.

    Each argument is one input feature, supplied by the Gradio components in
    the same positional order. The values are packed into a one-row DataFrame,
    passed through the module-level preprocessing ``pipeline``, and scored by
    the module-level ``model``.

    Returns:
        dict mapping a human-readable outcome label to its predicted
        probability (plain Python floats), as consumed by ``gr.Label``.
    """
    # One-row frame. The column names presumably must match those the
    # pipeline was fitted on — confirm against the training code.
    input_df = pd.DataFrame({
        'age': [age],
        'sex': [sex],
        'cp': [cp],
        'trestbps': [trestbps],
        'chol': [chol],
        'fbs': [fbs],
        'restecg': [restecg],
        'thalach': [thalach],
        'exang': [exang],
        'oldpeak': [oldpeak],
        'slope': [slope],
        'ca': [ca],
        'thal': [thal],
    })
    X_processed = pipeline.transform(input_df)
    # predict_proba returns one row per sample; take the only row.
    prediction_probs = model.predict_proba(X_processed)[0]
    # NOTE(review): assumes the model's class order is [0, 1] with 1 meaning
    # "heart disease present" — TODO confirm against model.classes_.
    return {
        "Prediction: HEART DISEASE 🔴": float(prediction_probs[1]),
        "Prediction: NO HEART DISEASE ✅": float(prediction_probs[0]),
    }
# Help text shown in the "information on inputs" accordion. It is kept
# flush-left so Markdown does not render the bullet list as indented code.
INPUTS_HELP_MD = """This app receives the following as inputs and processes them to return a prediction of whether a patient has heart disease.
- age: Age of the patient
- sex: Sex of the patient (0: Female, 1: Male)
- cp: Chest Pain Type (0: typical angina, 1: atypical angina, 2: non-anginal pain, 3: asymptomatic)
- trestbps: Resting Blood Pressure (in mm Hg on admission to the hospital)
- chol: Serum Cholesterol in mg/dl
- fbs: Fasting Blood Sugar > 120 mg/dl (0: No, 1: Yes)
- restecg: Resting Electrocardiographic results (0: normal, 1: having ST-T wave abnormality, 2: showing probable or definite left ventricular hypertrophy)
- thalach: Maximum Heart Rate Achieved
- exang: Exercise Induced Angina (0: No, 1: Yes)
- oldpeak: ST depression induced by exercise relative to rest
- slope: The slope of the peak exercise ST segment
- ca: Number of major vessels (0-3) colored by fluoroscopy
- thal: Thalassemia (0: normal, 1: fixed defect, 2: reversible defect, 3: unknown)
"""

# Input components, in the exact positional order of
# predict(age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang,
#         oldpeak, slope, ca, thal).
input_interface = []

with gr.Blocks(theme=gr.themes.Soft()) as app:
    # gr.Label is an output widget for classification confidences; a Markdown
    # heading is the appropriate component for a page title.
    with gr.Row():
        Title = gr.Markdown('# Heart Disease Prediction App')
    with gr.Row():
        gr.Markdown("This app predicts the likelihood that a patient has heart disease")
    with gr.Row():
        with gr.Column():
            input_interface_column_1 = [
                gr.Slider(label='Age', minimum=0, maximum=120, step=1),
                gr.Radio([0, 1], label='Sex'),
                gr.Slider(label='Chest Pain Type', minimum=0, maximum=3, step=1),
                gr.Slider(label='Resting Blood Pressure', minimum=0, maximum=200, step=1),
                gr.Slider(label='Cholesterol', minimum=0, maximum=600, step=1),
                gr.Radio([0, 1], label='Fasting Blood Sugar > 120 mg/dl'),
            ]
        with gr.Column():
            input_interface_column_2 = [
                gr.Slider(label='Resting ECG', minimum=0, maximum=2, step=1),
                gr.Slider(label='Max Heart Rate Achieved', minimum=60, maximum=220, step=1),
                gr.Radio([0, 1], label='Exercise Induced Angina'),
                gr.Slider(label='ST Depression Induced by Exercise', minimum=0.0, maximum=10.0, step=0.1),
                gr.Slider(label='Slope of Peak Exercise ST Segment', minimum=0, maximum=2, step=1),
                gr.Slider(label='Number of Major Vessels (0-3)', minimum=0, maximum=3, step=1),
                gr.Slider(label='Thalassemia (0-3)', minimum=0, maximum=3, step=1),
            ]
    # Column 1 followed by column 2 reproduces predict's parameter order.
    # This is plain list bookkeeping, so it needs no layout container.
    input_interface.extend(input_interface_column_1)
    input_interface.extend(input_interface_column_2)
    with gr.Row():
        predict_btn = gr.Button('Predict')
        output_interface = gr.Label(label="Prediction")
    with gr.Accordion("Open for information on inputs", open=False):
        gr.Markdown(INPUTS_HELP_MD)
    predict_btn.click(fn=predict, inputs=input_interface, outputs=output_interface)

app.launch(share=True)