Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import pandas as pd | |
| import joblib, os | |
# Resolve every artifact location relative to this file, so the app does not
# depend on the current working directory.
_this_file = os.path.abspath(__file__)
script_dir = os.path.dirname(_this_file)
_toolkit_dir = os.path.join(script_dir, 'toolkit')

pipeline_path = os.path.join(_toolkit_dir, 'pipeline.joblib')
model_path = os.path.join(_toolkit_dir, 'Random Forest Classifier.joblib')
heart_data_path = os.path.join(script_dir, 'heart.csv')

# Fitted transformation pipeline and classifier, both used by predict().
pipeline = joblib.load(pipeline_path)
model = joblib.load(model_path)

# The heart.csv data; not referenced by the other code visible in this file.
heart_df = pd.read_csv(heart_data_path)
def predict(age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal):
    """Return the model's class probabilities for a single patient record.

    The thirteen arguments are placed in a one-row DataFrame under the column
    names listed below, transformed with the module-level ``pipeline`` and
    scored with ``model.predict_proba``.

    Returns:
        dict: two display labels mapped to plain ``float`` probabilities,
        suitable as the value of a ``gr.Label`` output.

    Raises:
        gr.Error: if any argument is ``None`` (presumably what Gradio passes
            for a Radio with no option selected), so the user sees which
            fields are missing instead of an opaque pipeline failure.
    """
    features = {
        'age': age,
        'sex': sex,
        'cp': cp,
        'trestbps': trestbps,
        'chol': chol,
        'fbs': fbs,
        'restecg': restecg,
        'thalach': thalach,
        'exang': exang,
        'oldpeak': oldpeak,
        'slope': slope,
        'ca': ca,
        'thal': thal,
    }

    # Fail early with a readable message rather than inside pipeline.transform.
    missing = [name for name, value in features.items() if value is None]
    if missing:
        raise gr.Error("Please provide a value for: " + ", ".join(missing))

    # One-row frame; column order follows the dict above.
    input_df = pd.DataFrame([features])

    # Process input data using the pipeline, then score it with the model.
    X_processed = pipeline.transform(input_df)
    prediction_probs = model.predict_proba(X_processed)[0]

    # The inputs are heart-disease features, so the labels name that outcome
    # (the previous "CHURN"/"STAY" wording came from a different app).
    # NOTE(review): index 1 is treated as the positive class, as in the
    # original mapping - confirm against model.classes_ and the target
    # encoding used for training.
    return {
        "Prediction: HEART DISEASE 🔴": float(prediction_probs[1]),
        "Prediction: NO HEART DISEASE ✅": float(prediction_probs[0]),
    }
# Build the Gradio interface. The wording describes heart-disease prediction,
# matching the cardiac inputs collected below and the heart.csv data this app
# loads (the previous text described customer churn).
with gr.Blocks(theme=gr.themes.Soft()) as app:
    # Create the title inside its row; a bare reference to an already created
    # component is a no-op and does not place it there.
    with gr.Row():
        Title = gr.Label('Heart Disease Prediction App')
    with gr.Row():
        gr.Markdown("This app predicts the likelihood that a patient has heart disease")
    with gr.Row():
        with gr.Column():
            input_interface_column_1 = [
                gr.components.Slider(label='Age', minimum=0, maximum=120, step=1),
                gr.components.Radio([0, 1], label='Sex'),
                gr.components.Slider(label='Chest Pain Type', minimum=0, maximum=3, step=1),
                gr.components.Slider(label='Resting Blood Pressure', minimum=0, maximum=200, step=1),
                gr.components.Slider(label='Cholesterol', minimum=0, maximum=600, step=1),
                gr.components.Radio([0, 1], label='Fasting Blood Sugar > 120 mg/dl'),
            ]
        with gr.Column():
            input_interface_column_2 = [
                gr.components.Slider(label='Resting ECG', minimum=0, maximum=2, step=1),
                gr.components.Slider(label='Max Heart Rate Achieved', minimum=60, maximum=220, step=1),
                gr.components.Radio([0, 1], label='Exercise Induced Angina'),
                gr.components.Slider(label='ST Depression Induced by Exercise', minimum=0.0, maximum=10.0, step=0.1),
                gr.components.Slider(label='Slope of Peak Exercise ST Segment', minimum=0, maximum=2, step=1),
                gr.components.Slider(label='Number of Major Vessels (0-3)', minimum=0, maximum=3, step=1),
                gr.components.Slider(label='Thalassemia (0-3)', minimum=0, maximum=3, step=1),
            ]

    # The order of this list must match predict()'s positional parameters:
    # age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak,
    # slope, ca, thal. No layout container is needed to build a plain list.
    input_interface = input_interface_column_1 + input_interface_column_2

    # NOTE(review): the original indentation was lost, so placing the output
    # label in the same row as the button is an assumption - confirm.
    with gr.Row():
        predict_btn = gr.Button('Predict')
        output_interface = gr.Label(label="Heart disease risk")

    with gr.Accordion("Open for information on inputs", open=False):
        gr.Markdown("""This app receives the following as inputs and processes them to return the prediction on whether a patient has heart disease or not.
- age: Age of the patient
- sex: Sex of the patient (0: Female, 1: Male)
- cp: Chest Pain Type (0: typical angina, 1: atypical angina, 2: non-anginal pain, 3: asymptomatic)
- trestbps: Resting Blood Pressure (in mm Hg on admission to the hospital)
- chol: Serum Cholesterol in mg/dl
- fbs: Fasting Blood Sugar > 120 mg/dl (0: No, 1: Yes)
- restecg: Resting Electrocardiographic results (0: normal, 1: having ST-T wave abnormality, 2: showing probable or definite left ventricular hypertrophy)
- thalach: Maximum Heart Rate Achieved
- exang: Exercise Induced Angina (0: No, 1: Yes)
- oldpeak: ST depression induced by exercise relative to rest
- slope: The slope of the peak exercise ST segment
- ca: Number of major vessels (0-3) colored by fluoroscopy
- thal: Thalassemia (0: normal, 1: fixed defect, 2: reversible defect, 3: unknown)
""")

    # Event listeners are registered while the Blocks context is still open.
    predict_btn.click(fn=predict, inputs=input_interface, outputs=output_interface)

# Start the server when the script runs; share=True requests a public link.
app.launch(share=True)