| import streamlit as st |
| import numpy as np |
| import pickle |
| from tensorflow.keras.models import load_model |
|
|
| |
@st.cache_resource
def _load_keras_model_cached():
    """Load ``model_churn.keras`` (path relative to the working directory).

    Only successful loads are cached by ``st.cache_resource``; any exception
    propagates to the caller and nothing is stored, so a later call retries.
    """
    return load_model("model_churn.keras")


def load_keras_model():
    """Load the saved Keras model.

    Returns:
        The loaded Keras model, or ``None`` if loading failed. On failure the
        error is shown in the app via ``st.error``.

    A successfully loaded model is cached for the lifetime of the server
    process. A failure is deliberately NOT cached: the load is retried on the
    next rerun, so fixing or adding the model file takes effect without a
    server restart or manual cache clear.
    """
    try:
        return _load_keras_model_cached()
    except Exception as e:  # UI boundary: report any load failure to the user
        st.error(f"Error loading Keras model (model_churn.keras): {e}")
        return None


# Keep the `.clear()` handle the cache decorator used to expose on this name.
load_keras_model.clear = _load_keras_model_cached.clear
|
|
| |
@st.cache_resource
def _load_scaler_cached():
    """Unpickle the scaler from ``scaler_churn`` (path relative to the working directory).

    Only successful loads are cached by ``st.cache_resource``; any exception
    (including ``FileNotFoundError``) propagates and nothing is stored.
    """
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only deploy a trusted, locally produced scaler here; never point this
    # at user-supplied data.
    with open("scaler_churn", "rb") as f:
        return pickle.load(f)


def load_scaler_model():
    """Load the saved Scaler (pickle).

    Returns:
        The unpickled scaler object, or ``None`` if loading failed. On failure
        the error is shown in the app via ``st.error``.

    A successful load is cached for the lifetime of the server process. A
    failure is NOT cached, so once the missing file is uploaded, the next rerun
    picks it up, as the "not found" message asks the user to do.
    """
    try:
        return _load_scaler_cached()
    except FileNotFoundError:
        st.error("Error: 'scaler_churn' file not found. Please make sure the file is uploaded.")
        return None
    except Exception as e:  # UI boundary: report any other load failure
        st.error(f"Error loading Scaler (scaler_churn): {e}")
        return None


# Keep the `.clear()` handle the cache decorator used to expose on this name.
load_scaler_model.clear = _load_scaler_cached.clear
| |
|
|
| |
# Configure the page before any other Streamlit call. The loaders below may
# emit `st.error` (and cached functions may show a spinner). Streamlit has
# required `set_page_config` to be the first Streamlit command, raising
# otherwise.
st.set_page_config(page_title="Customer Churn Prediction", layout="wide")
st.title("🏦 Customer Churn Prediction Model")

# Either may be None if loading failed; the prediction code checks for that.
model = load_keras_model()
scaler = load_scaler_model()

# Feature order of the array fed to the scaler and then the model.
# NOTE(review): must match the column order used when the scaler/model were
# fitted — TODO confirm against the training code.
# One-hot encoding drops a baseline category: France (no Geography_France
# column) and Female (no Gender_Female column) are represented by all zeros.
MODEL_INPUT_COLUMNS_ORDER = [
    'CreditScore', 'Age', 'Tenure', 'Balance', 'NumOfProducts', 'HasCrCard',
    'IsActiveMember', 'EstimatedSalary', 'Geography_Germany', 'Geography_Spain',
    'Gender_Male',
]
|
|
col1, col2 = st.columns([1, 1])

# Churn probability (in percent) above which a customer is flagged HIGH risk.
CHURN_RISK_THRESHOLD_PERCENT = 50.0

# Left column: customer input widgets and the trigger button.
with col1:
    st.header("Customer Information")
    st.caption("Please enter customer information.")

    geography = st.selectbox("Country", ("France", "Spain", "Germany"), index=0)
    gender = st.radio("Gender", ("Female", "Male"), index=0)
    has_cr_card = st.radio("Has Credit Card?", ("Yes", "No"), index=0)
    is_active_member = st.radio("Active Member?", ("Yes", "No"), index=0)

    credit_score = st.number_input("Credit Score", min_value=300, max_value=900, value=619)
    age = st.number_input("Age", min_value=18, max_value=100, value=42)
    tenure = st.number_input("Tenure (Years)", min_value=0, max_value=10, value=2)
    # Monetary amounts: negative values are invalid input, so bound them at 0.
    balance = st.number_input("Balance", min_value=0.0, value=0.00, format="%.2f")
    num_of_products = st.number_input("Number of Products", min_value=1, max_value=4, value=1)
    estimated_salary = st.number_input(
        "Estimated Salary", min_value=0.0, value=101348.88, format="%.2f"
    )

    predict_button = st.button("Calculate Risk", type="primary")

# Right column: prediction output.
with col2:
    st.header("Prediction Result")
    st.caption("Model's churn probability prediction.")

    if predict_button:
        if model is None or scaler is None:
            st.error("Models could not be loaded. Please contact administrator.")
        else:
            try:
                # Encode the form values into the model's feature space:
                # binary Yes/No -> 1.0/0.0, one-hot Geography/Gender with
                # France/Female as the all-zero baseline.
                raw_data_dict = {
                    'CreditScore': credit_score,
                    'Age': age,
                    'Tenure': tenure,
                    'Balance': balance,
                    'NumOfProducts': float(num_of_products),
                    'HasCrCard': 1.0 if has_cr_card == "Yes" else 0.0,
                    'IsActiveMember': 1.0 if is_active_member == "Yes" else 0.0,
                    'EstimatedSalary': estimated_salary,
                    'Geography_Germany': 1.0 if geography == "Germany" else 0.0,
                    'Geography_Spain': 1.0 if geography == "Spain" else 0.0,
                    'Gender_Male': 1.0 if gender == "Male" else 0.0,
                }

                # Order features exactly as the scaler/model expect; one row.
                raw_input_features = [raw_data_dict[col] for col in MODEL_INPUT_COLUMNS_ORDER]
                raw_input_array = np.array(raw_input_features, dtype=float).reshape(1, -1)

                scaled_input_array = scaler.transform(raw_input_array)

                with st.spinner("Model running, making prediction..."):
                    # verbose=0 stops Keras printing a progress bar to the server
                    # log on every click. [0][0] reads the single output of the
                    # single input row. NOTE(review): assumes one sigmoid output
                    # unit — TODO confirm.
                    prediction_proba = float(
                        model.predict(scaled_input_array, verbose=0)[0][0]
                    )
            except Exception as e:  # UI boundary: show preprocessing/inference failures
                st.error(f"An error occurred during prediction: {e}")
            else:
                # Rendering runs outside the try so its own bugs are not
                # misreported as prediction errors.
                churn_probability_percent = prediction_proba * 100

                if churn_probability_percent > CHURN_RISK_THRESHOLD_PERCENT:
                    st.error(f"Customer Churn Probability: {churn_probability_percent:.2f}%")
                    st.warning("This customer has HIGH churn risk. 🚨")
                else:
                    st.success(f"Customer Churn Probability: {churn_probability_percent:.2f}%")
                    st.info("This customer has LOW churn risk. ✅")

                with st.expander("Processed (Scaled) Data Seen by Model"):
                    scaled_features_list = scaled_input_array.flatten().tolist()
                    st.json(dict(zip(MODEL_INPUT_COLUMNS_ORDER, scaled_features_list)))

                with st.expander("Raw Data Entered to Scaler (for verification)"):
                    st.json(dict(zip(MODEL_INPUT_COLUMNS_ORDER, raw_input_features)))
|
|