File size: 4,811 Bytes
bef5c48
79de571
bef5c48
 
 
 
 
 
 
 
 
 
26b9374
 
 
 
 
bef5c48
 
26b9374
bef5c48
 
 
 
 
 
 
 
 
 
 
 
 
 
26b9374
bef5c48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import gradio as gr
import pandas as pd
import joblib, os

# Anchor every artifact path to the directory that contains this script.
script_dir = os.path.dirname(os.path.abspath(__file__))
toolkit_dir = os.path.join(script_dir, 'toolkit')
pipeline_path = os.path.join(toolkit_dir, 'pipeline.joblib')
model_path = os.path.join(toolkit_dir, 'Random Forest Classifier.joblib')
heart_data_path = os.path.join(script_dir, 'heart.csv')

# Deserialize the transformation pipeline first, then the classifier.
pipeline, model = (joblib.load(path) for path in (pipeline_path, model_path))

# Read the heart.csv dataset into a DataFrame.
heart_df = pd.read_csv(heart_data_path)

# Define the predict function
def predict(age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal):
    """Return heart-disease class probabilities for a single patient record.

    The thirteen arguments are placed in a one-row DataFrame whose column
    names match the parameter names, passed through the module-level
    ``pipeline`` and scored with the module-level ``model``.

    Returns:
        dict[str, float]: two display labels mapped to the probabilities
        reported by ``model.predict_proba`` (a mapping of label to
        confidence, the form ``gr.Label`` accepts).
    """
    # One-row frame; column names and order mirror the function signature.
    input_df = pd.DataFrame({
        'age': [age],
        'sex': [sex],
        'cp': [cp],
        'trestbps': [trestbps],
        'chol': [chol],
        'fbs': [fbs],
        'restecg': [restecg],
        'thalach': [thalach],
        'exang': [exang],
        'oldpeak': [oldpeak],
        'slope': [slope],
        'ca': [ca],
        'thal': [thal],
    })

    # Apply the same preprocessing object that was loaded alongside the model.
    X_processed = pipeline.transform(input_df)

    # predict_proba returns one row per sample; only one sample is scored here.
    prediction_probs = model.predict_proba(X_processed)[0]

    # NOTE(review): assumes index 1 is the positive (disease) class, following
    # the original index-to-label mapping - confirm against model.classes_.
    # Labels describe heart disease (the inputs are cardiac features), not churn.
    # float() turns the numpy scalars into plain Python floats for the UI layer.
    return {
        "Prediction: HEART DISEASE 🔴": float(prediction_probs[1]),
        "Prediction: NO HEART DISEASE ✅": float(prediction_probs[0]),
    }

# Flat list of input components, in the same order as predict()'s parameters.
input_interface = []

with gr.Blocks(theme=gr.themes.Soft()) as app:

    # A component is added to the layout where it is instantiated inside the
    # Blocks context, so the title needs no extra wrapper row.
    Title = gr.Label('Heart Disease Prediction App')

    with gr.Row():
        gr.Markdown("This app predicts the likelihood that a patient has heart disease")

    # Two side-by-side columns of inputs.
    with gr.Row():
        with gr.Column():
            input_interface_column_1 = [
                gr.components.Slider(label='Age', minimum=0, maximum=120, step=1),
                gr.components.Radio([0, 1], label='Sex'),
                gr.components.Slider(label='Chest Pain Type', minimum=0, maximum=3, step=1),
                gr.components.Slider(label='Resting Blood Pressure', minimum=0, maximum=200, step=1),
                gr.components.Slider(label='Cholesterol', minimum=0, maximum=600, step=1),
                gr.components.Radio([0, 1], label='Fasting Blood Sugar > 120 mg/dl'),
            ]

        with gr.Column():
            input_interface_column_2 = [
                gr.components.Slider(label='Resting ECG', minimum=0, maximum=2, step=1),
                gr.components.Slider(label='Max Heart Rate Achieved', minimum=60, maximum=220, step=1),
                gr.components.Radio([0, 1], label='Exercise Induced Angina'),
                gr.components.Slider(label='ST Depression Induced by Exercise', minimum=0.0, maximum=10.0, step=0.1),
                gr.components.Slider(label='Slope of Peak Exercise ST Segment', minimum=0, maximum=2, step=1),
                gr.components.Slider(label='Number of Major Vessels (0-3)', minimum=0, maximum=3, step=1),
                gr.components.Slider(label='Thalassemia (0-3)', minimum=0, maximum=3, step=1),
            ]

    # Pure list bookkeeping (no layout needed): column 1 then column 2 gives
    # the positional order age, sex, cp, ..., thal that predict() expects.
    input_interface.extend(input_interface_column_1)
    input_interface.extend(input_interface_column_2)

    with gr.Row():
        predict_btn = gr.Button('Predict')
        output_interface = gr.Label(label="Heart disease prediction")

    with gr.Accordion("Open for information on inputs", open=False):
        gr.Markdown("""This app receives the following as inputs and processes them to return the prediction on whether a patient is likely to have heart disease or not.

                    - age: Age of the patient
                    - sex: Sex of the patient (0: Female, 1: Male)
                    - cp: Chest Pain Type (0: typical angina, 1: atypical angina, 2: non-anginal pain, 3: asymptomatic)
                    - trestbps: Resting Blood Pressure (in mm Hg on admission to the hospital)
                    - chol: Serum Cholesterol in mg/dl
                    - fbs: Fasting Blood Sugar > 120 mg/dl (0: No, 1: Yes)
                    - restecg: Resting Electrocardiographic results (0: normal, 1: having ST-T wave abnormality, 2: showing probable or definite left ventricular hypertrophy)
                    - thalach: Maximum Heart Rate Achieved
                    - exang: Exercise Induced Angina (0: No, 1: Yes)
                    - oldpeak: ST depression induced by exercise relative to rest
                    - slope: The slope of the peak exercise ST segment
                    - ca: Number of major vessels (0-3) colored by fluoroscopy
                    - thal: Thalassemia (0: normal, 1: fixed defect, 2: reversible defect, 3: unknown)
                    """)

    # Wire the button: component values are passed to predict() positionally.
    predict_btn.click(fn=predict, inputs=input_interface, outputs=output_interface)

# Start the server only when run as a script, so importing this module
# (e.g. for testing) does not launch the app.
if __name__ == "__main__":
    app.launch(share=True)