Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import joblib | |
| import pandas as pd | |
| from PIL import Image | |
# Load the model and image
@st.cache_resource
def load_model(path="best_model.pkl"):
    """Load the trained churn classifier from *path*.

    Streamlit re-executes this whole script on every widget interaction, so
    the result is cached with ``st.cache_resource``: the file is deserialized
    once per server process and the same object is reused across reruns and
    sessions instead of being reloaded from disk each time.

    Args:
        path: Location of the joblib-serialized model. Defaults to the file
            bundled with this app.

    Returns:
        The deserialized model object.
    """
    # joblib.load unpickles arbitrary Python objects: only point this at a
    # trusted file shipped with the app, never at user-supplied data.
    return joblib.load(path)
@st.cache_resource
def load_roc_image(path="roc_curve_rf_tuned.png"):
    """Load the ROC-curve plot displayed next to prediction results.

    ``Image.open`` is lazy: it keeps the file open until the pixel data is
    read. Copying inside the ``with`` block forces the read and closes the
    file right away. The in-memory image is cached with
    ``st.cache_resource`` so it is decoded once rather than on every rerun.

    Args:
        path: Location of the image file. Defaults to the bundled plot.

    Returns:
        A fully loaded in-memory PIL image.
    """
    with Image.open(path) as img:
        return img.copy()
# Load the model and ROC plot up front. If either cannot be loaded the app
# has nothing to work with, so report the failure and stop this script run.
try:
    best_model = load_model()
    roc_img = load_roc_image()
except Exception as load_error:
    st.error("Error loading model or image: " + str(load_error))
    st.stop()
# App title and description
st.title("Customer Churn Prediction")
st.write("Enter customer information to predict likelihood of churn")

# Two side-by-side columns of input widgets. Each widget's current value is
# bound to a module-level name that make_prediction() reads.
left, right = st.columns(2)

# Left column: customer profile and engagement.
age = left.slider("Age", min_value=18, max_value=100, value=40)
gender = left.selectbox("Gender", options=["Male", "Female"])
tenure = left.slider("Tenure (months)", min_value=1, max_value=60, value=30)
usage_frequency = left.slider(
    "Usage Frequency", min_value=1, max_value=30, value=15
)
support_calls = left.slider("Support Calls", min_value=0, max_value=10, value=4)

# Right column: billing and subscription details.
payment_delay = right.slider("Payment Delay", min_value=0, max_value=30, value=15)
last_interaction = right.slider(
    "Last Interaction (days ago)", min_value=1, max_value=30, value=15
)
total_spend = right.slider("Total Spend", min_value=100, max_value=1000, value=620)
subscription_type = right.selectbox(
    "Subscription Type", options=["Premium", "Standard", "Basic"]
)
contract_length = right.selectbox(
    "Contract Length", options=["Monthly", "Quarterly", "Annual"]
)
# Prediction function
def make_prediction():
    """Score the customer described by the current widget values.

    Reads the module-level inputs (``age``, ``gender``, ``tenure``,
    ``usage_frequency``, ``support_calls``, ``payment_delay``,
    ``last_interaction``, ``total_spend``, ``subscription_type``,
    ``contract_length``) and the loaded ``best_model``. It one-hot encodes the
    categorical inputs and runs the model on a single-row DataFrame.

    Returns:
        tuple: ``(prediction, probability)``. ``prediction`` is the model's
        predicted label for the row. ``probability`` is column 1 of
        ``predict_proba``, i.e. the second class in the model's class order
        (presumably the churn class; confirm against training labels).

    Raises:
        ValueError: if the model records training feature names
            (``feature_names_in_``) that this function does not produce.
    """
    # NOTE: keep this insertion order. A model fitted without feature names
    # sees these columns purely by position.
    input_data = {
        "Age": age,
        "Gender_Male": 1 if gender == "Male" else 0,
        "Gender_Female": 1 if gender == "Female" else 0,
        "Usage Frequency": usage_frequency,
        "Support Calls": support_calls,
        "Contract Length_Monthly": 1 if contract_length == "Monthly" else 0,
        "Contract Length_Quarterly": 1 if contract_length == "Quarterly" else 0,
        "Contract Length_Annual": 1 if contract_length == "Annual" else 0,
        "Payment Delay": payment_delay,
        "Last Interaction": last_interaction,
        "Total Spend": total_spend,
        "Tenure": tenure,
        "Subscription Type_Basic": 1 if subscription_type == "Basic" else 0,
        "Subscription Type_Premium": 1 if subscription_type == "Premium" else 0,
        "Subscription Type_Standard": 1 if subscription_type == "Standard" else 0,
    }
    input_df = pd.DataFrame([input_data])

    # scikit-learn estimators fitted on a DataFrame store the training column
    # order in ``feature_names_in_``. At predict time they reject input whose
    # columns differ in name or order, so select and reorder to match. Extra
    # columns the model never saw (e.g. a dummy dropped during training) are
    # discarded. Missing ones are reported explicitly.
    expected = getattr(best_model, "feature_names_in_", None)
    if expected is not None:
        missing = [name for name in expected if name not in input_df.columns]
        if missing:
            raise ValueError(
                f"Model expects features the app does not provide: {missing}"
            )
        input_df = input_df[list(expected)]

    # Predict churn and probability
    prediction = best_model.predict(input_df)
    prediction_proba = best_model.predict_proba(input_df)[:, 1]
    return prediction[0], prediction_proba[0]
# When the button is pressed, score the current inputs and render the
# prediction, probability, a coarse risk band and the model's ROC curve.
# Any failure along the way is reported in the page instead of crashing it.
if st.button("Predict Churn"):
    try:
        prediction, probability = make_prediction()

        st.header("Prediction Results")
        pred_col, prob_col, risk_col = st.columns(3)

        pred_col.metric("Churn Prediction", "Yes" if prediction == 1 else "No")
        prob_col.metric("Churn Probability", f"{probability:.2f}")

        # Risk band: above 0.7 is High, above 0.4 is Medium, otherwise Low.
        if probability > 0.7:
            risk_level = "High"
        elif probability > 0.4:
            risk_level = "Medium"
        else:
            risk_level = "Low"
        risk_col.metric("Risk Level", risk_level)

        st.subheader("Model ROC Curve")
        st.image(roc_img, caption="ROC Curve for Random Forest Model")
    except Exception as exc:
        st.error("Error making prediction: " + str(exc))