File size: 3,921 Bytes
a795221
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import streamlit as st
import pandas as pd
import pickle
import numpy as np

# Centered red page heading; raw HTML requires unsafe_allow_html.
_TITLE_HTML = "<h1 style='text-align: center; font-size: 48px; color: red;'>Bank Customer Churn Prediction </h1>"
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Load the dataset
@st.cache_data
def load_dataset(path="BankChurners.csv"):
    """Read the churn dataset CSV at *path* and return it as a DataFrame.

    Wrapped in ``st.cache_data`` so the file is parsed once per distinct
    *path* instead of on every Streamlit rerun.

    Args:
        path: Location of the CSV file. Defaults to ``BankChurners.csv`` in
            the working directory, matching the previous hard-coded value.
    """
    return pd.read_csv(path)

# Cache function to convert DataFrame to CSV
@st.cache_data
def convert_df(df):
    """Serialize *df* to CSV (without the index) and return it as UTF-8 bytes.

    ``st.cache_data`` memoizes the result, so re-encoding an identical frame
    on a rerun is skipped.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode("utf-8")

# Load pre-trained model
@st.cache_resource
def load_model(path="model.pkl"):
    """Unpickle and return the pre-trained churn model stored at *path*.

    Wrapped in ``st.cache_resource`` so the model object is loaded once and
    reused across reruns rather than re-read from disk each time.

    Args:
        path: Pickle file holding the model. Defaults to ``model.pkl``,
            matching the previous hard-coded value.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the
    file; only point *path* at a model file you trust.
    """
    with open(path, "rb") as f:
        return pickle.load(f)

# Load the dataset (cached) and show a preview so the user can pick a row.
df = load_dataset()

st.write("Dataset Preview:")
st.dataframe(df.head())

# Columns the app needs. This list's order is ALSO the feature order passed to
# the model (see _encode_row), so keep it in sync with how model.pkl was trained.
# NOTE(review): CLIENTNUM looks like a customer identifier; feeding it to the
# model is unusual — confirm the model was really trained with it.
required_columns = ['CLIENTNUM', 'Total_Trans_Ct', 'Total_Ct_Chng_Q4_Q1', 'Total_Revolving_Bal',
                    'Avg_Utilization_Ratio', 'Total_Trans_Amt', 'Total_Relationship_Count',
                    'Total_Amt_Chng_Q4_Q1', 'Gender', 'Credit_Limit', 'Card_Category', 'Avg_Open_To_Buy']

# Numeric codes for the categorical feature columns.
# NOTE(review): these codes must match the encoding used to train model.pkl.
# Tiers absent from this table (e.g. 'Silver', if it occurs in the data) have
# no code — confirm against the training pipeline. Unmapped values are
# reported to the user instead of raising KeyError.
gender_mapping = {"M": 0, "F": 1}
card_category_mapping = {"Blue": 0, "Gold": 1, "Platinum": 2, "Titanium": 3}
categorical_encoders = {"Gender": gender_mapping, "Card_Category": card_category_mapping}


def _encode_row(row):
    """Return the model input vector for one dataset row.

    Values are read in ``required_columns`` order. Columns listed in
    ``categorical_encoders`` are replaced by their numeric code.

    Raises:
        ValueError: if a categorical column holds a value with no code
            (this includes missing/NaN values).
    """
    features = []
    for col in required_columns:
        value = row[col]
        encoder = categorical_encoders.get(col)
        if encoder is not None:
            if value not in encoder:
                raise ValueError(
                    f"Cannot encode {col}={value!r}: expected one of {', '.join(encoder)}."
                )
            value = encoder[value]
        features.append(value)
    return features


def _predict_and_report(row_to_use):
    """Run the model on *row_to_use* and render the result plus a CSV download."""
    try:
        row_to_use_for_model = _encode_row(row_to_use)
    except ValueError as err:
        st.error(str(err))
        return

    model = load_model()

    # Guard against a feature-count mismatch. Models that do not expose
    # n_features_in_ skip the check instead of crashing.
    expected = getattr(model, "n_features_in_", None)
    if expected is not None and len(row_to_use_for_model) != expected:
        st.error(f"The model expects {expected} features, but {len(row_to_use_for_model)} were provided.")
        return

    prediction = model.predict([row_to_use_for_model])

    st.write("Row selected for churn prediction:")
    st.write(row_to_use)

    result = "Likely to Churn" if prediction[0] == 1 else "Likely to Stay"
    st.write(f"Churn Prediction Result: {result}")

    # One-row DataFrame (Series transposed) with the prediction appended.
    result_df = row_to_use.to_frame().T
    result_df['Churn Prediction'] = result
    st.download_button(
        label="Download Prediction Result",
        data=convert_df(result_df),
        file_name="Churn_Prediction_Result.csv",
        mime="text/csv",
    )


missing_columns = [col for col in required_columns if col not in df.columns]

if missing_columns:
    st.error(
        f"The dataset must contain the following columns: {', '.join(required_columns)}. "
        f"Missing: {', '.join(missing_columns)}"
    )
elif df.empty:
    # With no rows, selectbox would return None and df.iloc[None] would fail.
    st.warning("The dataset has no rows to select for prediction.")
else:
    st.markdown("### Select a Customer Row for Prediction:")

    # Positional row number (0 .. len(df)-1), used with df.iloc below.
    selected_row_index = st.selectbox("Select a Row Index", options=range(len(df)), index=0)

    # st.button is True only on the rerun triggered by the click.
    if st.button("Predict Churn"):
        _predict_and_report(df.iloc[selected_row_index])