Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import shap | |
| import matplotlib.pyplot as plt | |
| from utils.data_loader import ( | |
| load_raw_data, get_train_test, get_onehot_train_test, | |
| get_encoded_data, CATEGORICAL_COLS, NUMERIC_COLS, | |
| ) | |
| from utils.models import ( | |
| train_tree_models, train_lr_model, train_nb_model, evaluate_model, | |
| get_shap_explainer, get_shap_single, | |
| ) | |
| from utils.visualizations import plot_roc_curves, plot_confusion_matrix, plot_gauge | |
# Page chrome: browser-tab title and icon, wide layout, then the visible heading
# and a horizontal rule. The icon must be a real emoji (or an image reference);
# the mis-encoded two-character string that was here is neither.
st.set_page_config(page_title="Churn Models", page_icon="🤖", layout="wide")
st.title("Churn Prediction Models")
st.markdown("---")
# Label-encoded split; used further down for the tree models and Naive Bayes.
X_train, X_test, y_train, y_test, encoders, feature_cols = get_train_test()
# One-hot split; its third and fourth return values are ignored here.
X_train_oh, X_test_oh, _, _, feature_cols_oh = get_onehot_train_test()

# First run in this browser session: create every slot the training step fills later.
if "churn_models_trained" not in st.session_state:
    for slot, initial in (
        ("churn_models_trained", False),
        ("all_models", None),
        ("model_test_data", None),
        ("metrics", None),
    ):
        st.session_state[slot] = initial
# Banner explaining that the models are fed differently encoded feature sets,
# with the column count of each encoding filled in from the loaded data.
onehot_width = len(feature_cols_oh)
label_width = len(feature_cols)
encoding_note = (
    "**Encoding note:** Logistic Regression is trained on **One-Hot Encoded** features "
    f"({onehot_width} columns) while Random Forest and XGBoost use "
    f"**Label Encoding** ({label_width} columns). "
    "Each model gets the encoding that's optimal for it."
)
st.info(encoding_note)
# Gate the rest of the page on training: until the flag in session state is set,
# show a prompt and a button, and halt the script before any tab is rendered.
if not st.session_state.churn_models_trained:
    st.warning("⚠️ Models not trained yet. Click the button below to train all models.")
    if st.button("🚀 Train All Models", type="primary", use_container_width=True):
        with st.spinner("Training models..."):
            # Logistic Regression is fitted on the one-hot features; the tree
            # models and Naive Bayes are fitted on the label-encoded features.
            tree_models = train_tree_models(X_train, y_train)
            lr_model = train_lr_model(X_train_oh, y_train)
            nb_model = train_nb_model(X_train, y_train)

            all_models = {
                "Logistic Regression": lr_model,
                "Naive Bayes": nb_model,
            }
            all_models.update(tree_models)

            # Pair every model with the test matrix in the encoding it was trained
            # on. Derived from all_models instead of a second hard-coded list of
            # names, so whatever keys train_tree_models returns are always covered
            # and the evaluation loop cannot hit a KeyError.
            model_test_data = {
                name: (X_test_oh if name == "Logistic Regression" else X_test)
                for name in all_models
            }
            metrics = {
                name: evaluate_model(model, model_test_data[name], y_test)
                for name, model in all_models.items()
            }

            # Persist everything so later reruns skip training.
            st.session_state.all_models = all_models
            st.session_state.model_test_data = model_test_data
            st.session_state.metrics = metrics
            st.session_state.churn_models_trained = True
        # Restart the script so it takes the "already trained" path from the top.
        st.rerun()
    # Not trained and no click on this run: stop before the tabs below.
    st.stop()
# Models are trained - pull the stored artefacts back out of session state.
all_models = st.session_state.all_models
model_test_data = st.session_state.model_test_data
metrics = st.session_state.metrics
st.success("✅ All models trained and ready!")

# One tab per page section; the strings are the user-visible tab titles and the
# unpacking order must match the order of the labels.
(
    tab_how,
    tab_predict,
    tab_compare,
    tab_shap_global,
    tab_shap_individual,
    tab_whatif,
) = st.tabs(
    [
        "How The Algorithms Work",
        "Predict and Compare",
        "Model Comparison",
        "Feature Importance (SHAP)",
        "Individual Explanations",
        "What-If Predictor",
    ]
)
| # ── Tab: How The Algorithms Work ───────────────────────────────────────────── | |
| with tab_how: | |
| st.subheader("How Each Algorithm Works β A Visual Guide") | |
| st.markdown( | |
| "We use four very different algorithms to predict churn. Each approaches " | |
| "the problem in its own way. Understanding the differences helps explain " | |
| "why one model might outperform another and when predictions diverge." | |
| ) | |
| # ── Logistic Regression ────────────────────────────────────────────────── | |
| st.markdown("---") | |
| st.markdown("### 1. Logistic Regression") | |
| st.markdown("*The simplest model β and often a strong baseline.*") | |
| st.graphviz_chart(""" | |
| digraph lr { | |
| rankdir=LR | |
| node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.3,0.15"] | |
| edge [color="#888888", penwidth=1.5] | |
| features [label="Customer Features\\ntenure, charges,\\ncontract type, ...", fillcolor="#dbeafe", color="#3b82f6"] | |
| weighted [label="Weighted Sum\\nw₁×tenure + w₂×charges\\n+ w₃×contract + ...", fillcolor="#e0e7ff", color="#6366f1"] | |
| sigmoid [label="Sigmoid Function\\nSquash to 0-1", fillcolor="#fce7f3", color="#ec4899"] | |
| output [label="Probability\\n0.82 → Churn\\n0.15 → Retain", fillcolor="#d1fae5", color="#10b981"] | |
| features -> weighted -> sigmoid -> output | |
| } | |
| """) | |
| import plotly.graph_objects as go | |
| x_sig = np.linspace(-8, 8, 200) | |
| y_sig = 1 / (1 + np.exp(-x_sig)) | |
| fig_sig = go.Figure() | |
| fig_sig.add_trace(go.Scatter(x=x_sig, y=y_sig, mode="lines", line=dict(color="#6366f1", width=3), name="Sigmoid")) | |
| fig_sig.add_hline(y=0.5, line_dash="dash", line_color="gray", annotation_text="Decision boundary (0.5)") | |
| fig_sig.add_vrect(x0=-8, x1=0, fillcolor="#d1fae5", opacity=0.15, annotation_text="Retain", annotation_position="bottom left") | |
| fig_sig.add_vrect(x0=0, x1=8, fillcolor="#fee2e2", opacity=0.15, annotation_text="Churn", annotation_position="bottom right") | |
| fig_sig.update_layout( | |
| title="The Sigmoid Curve — Turns Any Number into a Probability", | |
| xaxis_title="Weighted Sum of Features (z = w₁x₁ + w₂x₂ + ...)", | |
| yaxis_title="Churn Probability", | |
| height=350, | |
| yaxis=dict(range=[0, 1]), | |
| ) | |
| st.plotly_chart(fig_sig, use_container_width=True) | |
| st.markdown( | |
| """ | |
| **How it works:** | |
| 1. Each customer feature gets a **weight** (a number the model learns). For example, | |
| "month-to-month contract" might get a high positive weight (pushes toward churn), | |
| while "long tenure" gets a negative weight (pushes away from churn). | |
| 2. The model adds up: `weight₁ × feature₁ + weight₂ × feature₂ + ...` | |
| 3. This sum could be any number (-∞ to +∞). The **sigmoid function** (the S-curve above) | |
| squashes it into a probability between 0 and 1. | |
| 4. If the probability > 0.5 → predict "Churn". Otherwise → "Retain". | |
| **Why we use One-Hot Encoding for this model:** Because the weighted sum treats numbers | |
| at face value. If we encoded "Month-to-month"=2 and "Two year"=0, the model would | |
| think month-to-month is "more" of something than two-year β which is nonsensical. | |
| One-Hot avoids this by giving each category its own yes/no column. | |
| **Strengths:** Fast, interpretable (weights directly tell you what matters), good baseline. | |
| **Weaknesses:** Assumes a linear relationship between features and the log-odds of churn. | |
| Can't capture complex interactions (e.g., "fiber optic is only risky for short-tenure customers"). | |
| """ | |
| ) | |
| with st.expander("Step-by-step numerical walkthrough β Logistic Regression", expanded=False): | |
| st.markdown( | |
| """ | |
| Let's trace through exactly how the model makes one prediction. | |
| **A single customer:** | |
| | Feature | Value | | |
| |---|---| | |
| | tenure | 3 months | | |
| | MonthlyCharges | $85 | | |
| | Contract_Month-to-month | 1 (one-hot) | | |
| | Contract_One year | 0 | | |
| | Contract_Two year | 0 | | |
| | InternetService_Fiber optic | 1 | | |
| | Partner | 0 (no) | | |
| (After one-hot encoding and scaling) | |
| --- | |
| **Step 1 β The model has learned weights (one per feature)** | |
| | Feature | Learned Weight | Meaning | | |
| |---|---|---| | |
| | tenure (scaled) | -1.2 | Longer tenure → less churn | | |
| | MonthlyCharges (scaled) | +0.5 | Higher charges → more churn | | |
| | Contract_Month-to-month | +1.4 | Month-to-month → high churn risk | | |
| | Contract_Two year | -0.9 | Two-year contract → protective | | |
| | InternetService_Fiber optic | +0.8 | Fiber optic → correlated with churn | | |
| | Partner | -0.3 | Having a partner → slightly protective | | |
| | (bias term) | -0.2 | Baseline offset | | |
| --- | |
| **Step 2 β Compute the weighted sum (z)** | |
| """ | |
| ) | |
| st.latex(r"z = b + w_1 x_1 + w_2 x_2 + w_3 x_3 + \ldots") | |
| st.markdown( | |
| """ | |
| For our customer (tenure scaled to -1.5 since 3 months is below average): | |
| `z = -0.2 + (-1.2 × -1.5) + (0.5 × 0.8) + (1.4 × 1) + (0.8 × 1) + (-0.3 × 0)` | |
| `z = -0.2 + 1.8 + 0.4 + 1.4 + 0.8 + 0` | |
| `z = 4.2` | |
| --- | |
| **Step 3 β Apply the sigmoid function** | |
| """ | |
| ) | |
| st.latex(r"P(\text{churn}) = \frac{1}{1 + e^{-z}} = \frac{1}{1 + e^{-4.2}} = \frac{1}{1 + 0.015} = 0.985") | |
| st.markdown( | |
| """ | |
| **Probability = 98.5%** → Predict **Churn** | |
| --- | |
| **Step 4 β How did the model learn these weights?** | |
| | Epoch | What happens | | |
| |---|---| | |
| | Start | All weights = 0 (or random small numbers) | | |
| | Each step | Pick a customer, predict, calculate error | | |
| | Update | Adjust weights using gradient descent | | |
| | | `weight = weight - learning_rate × gradient` | | |
| | Repeat | Until weights converge (error stops decreasing) | | |
| The gradient tells the model: "This weight should go up" or "down" based on | |
| whether increasing it would reduce the prediction error. | |
| After training on all ~5,600 training customers for many iterations, | |
| the weights settle to values that best separate churners from non-churners. | |
| --- | |
| **Why this customer scores so high:** | |
| | Feature | Contribution to z | | |
| |---|---| | |
| | Short tenure (-1.5 scaled) | +1.8 (weight is negative, value is negative → positive contribution) | | |
| | Month-to-month contract | +1.4 | | |
| | Fiber optic | +0.8 | | |
| | High charges | +0.4 | | |
| | **Total push toward churn** | **+4.2** → sigmoid → 98.5% | | |
| Every feature contributes additively. That's the key property (and limitation) | |
| of logistic regression β it can't model interactions like "fiber optic is | |
| only risky for short-tenure customers." | |
| """ | |
| ) | |
| # ── Naive Bayes ────────────────────────────────────────────────────────── | |
| st.markdown("---") | |
| st.markdown("### 2. Naive Bayes (Gaussian)") | |
| st.markdown("*A probabilistic model that assumes features are independent.*") | |
| st.graphviz_chart(""" | |
| digraph nb { | |
| rankdir=LR | |
| node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.3,0.15"] | |
| edge [color="#888888", penwidth=1.5] | |
| features [label="Customer Features\\ntenure, charges,\\ncontract type, ...", fillcolor="#dbeafe", color="#3b82f6"] | |
| prob [label="Calculate Probability\\nfor each class\\nusing Bayes' Theorem", fillcolor="#e0e7ff", color="#6366f1"] | |
| compare [label="Compare\\nP(Churn | features)\\nvs\\nP(Retain | features)", fillcolor="#fce7f3", color="#ec4899"] | |
| output [label="Pick Higher\\nProbability", fillcolor="#d1fae5", color="#10b981"] | |
| features -> prob -> compare -> output | |
| } | |
| """) | |
| st.markdown( | |
| """ | |
| **How it works:** | |
| Naive Bayes uses **Bayes' Theorem** to calculate the probability of churn given the customer's features. | |
| The "naive" part means it assumes all features are **independent** β knowing tenure doesn't tell | |
| you anything about monthly charges (which isn't true in reality, but the assumption simplifies the math). | |
| **Formula:** | |
| """ | |
| ) | |
| st.latex(r"P(\text{Churn} \mid \text{features}) = \frac{P(\text{features} \mid \text{Churn}) \times P(\text{Churn})}{P(\text{features})}") | |
| st.markdown( | |
| """ | |
| For continuous features (like tenure, charges), Gaussian Naive Bayes assumes each feature | |
| follows a **normal distribution** within each class. | |
| **Step-by-step example:** | |
| Suppose we want to predict if Customer X will churn. They have: | |
| - Tenure = 12 months | |
| - Monthly Charges = $70 | |
| **Step 1:** Calculate P(Tenure=12 | Churn) and P(Tenure=12 | Retain) | |
| From training data, we know: | |
| - Churners have average tenure = 18 months, std = 15 | |
| - Retainers have average tenure = 38 months, std = 20 | |
| Using the Gaussian (bell curve) formula: | |
| - P(Tenure=12 | Churn) = 0.024 (12 is close to churner average) | |
| - P(Tenure=12 | Retain) = 0.008 (12 is far from retainer average) | |
| **Step 2:** Do the same for Monthly Charges | |
| - P(Charges=70 | Churn) = 0.015 | |
| - P(Charges=70 | Retain) = 0.012 | |
| **Step 3:** Multiply probabilities (the "naive" independence assumption) | |
| - P(features | Churn) = 0.024 × 0.015 = 0.00036 | |
| - P(features | Retain) = 0.008 × 0.012 = 0.000096 | |
| **Step 4:** Apply Bayes' Theorem | |
| - P(Churn | features) ∝ 0.00036 × P(Churn) = 0.00036 × 0.27 = 0.0000972 | |
| - P(Retain | features) ∝ 0.000096 × P(Retain) = 0.000096 × 0.73 = 0.00007008 | |
| Normalize: P(Churn | features) = 0.0000972 / (0.0000972 + 0.00007008) = **58%** | |
| **Prediction: Churn (58% probability)** | |
| **Why Naive Bayes is different:** | |
| - Makes strong independence assumption (features don't interact) | |
| - Very fast to train and predict | |
| - Works well when features are actually somewhat independent | |
| - Often gives different predictions than tree-based models because it doesn't | |
| capture feature interactions (e.g., "fiber optic + short tenure" combo) | |
| """ | |
| ) | |
| # ── Random Forest ──────────────────────────────────────────────────────── | |
| st.markdown("---") | |
| st.markdown("### 3. Random Forest") | |
| st.markdown("*An ensemble of decision trees that vote together.*") | |
| st.graphviz_chart(""" | |
| digraph rf { | |
| node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"] | |
| edge [color="#888888", penwidth=1.5] | |
| data [label="Training Data", fillcolor="#dbeafe", color="#3b82f6"] | |
| subgraph cluster_trees { | |
| label="200 Decision Trees (each sees a random subset)" | |
| style=dashed | |
| color="#94a3b8" | |
| fontname="Helvetica" | |
| fontsize=10 | |
| t1 [label="Tree 1\\n→ Churn", fillcolor="#d1fae5", color="#10b981"] | |
| t2 [label="Tree 2\\n→ Retain", fillcolor="#d1fae5", color="#10b981"] | |
| t3 [label="Tree 3\\n→ Churn", fillcolor="#d1fae5", color="#10b981"] | |
| dots [label="...\\n(197 more)", shape=plaintext] | |
| t200 [label="Tree 200\\n→ Churn", fillcolor="#d1fae5", color="#10b981"] | |
| } | |
| vote [label="Majority Vote\\n3 out of 4 say Churn\\n→ Final: Churn (75%)", fillcolor="#fef3c7", color="#f59e0b"] | |
| data -> t1 | |
| data -> t2 | |
| data -> t3 | |
| data -> t200 | |
| t1 -> vote | |
| t2 -> vote | |
| t3 -> vote | |
| t200 -> vote | |
| } | |
| """) | |
| st.markdown("**How a single decision tree works:**") | |
| st.graphviz_chart(""" | |
| digraph tree_example { | |
| node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.2,0.1"] | |
| edge [fontname="Helvetica", fontsize=9, color="#888888", penwidth=1.5] | |
| q1 [label="Contract = \\nMonth-to-month?", fillcolor="#e0e7ff", color="#6366f1"] | |
| q2 [label="tenure < 12\\nmonths?", fillcolor="#e0e7ff", color="#6366f1"] | |
| q3 [label="MonthlyCharges\\n> $70?", fillcolor="#e0e7ff", color="#6366f1"] | |
| churn1 [label="CHURN\\n(85% probability)", fillcolor="#fee2e2", color="#ef4444"] | |
| retain1 [label="RETAIN\\n(70% probability)", fillcolor="#d1fae5", color="#10b981"] | |
| churn2 [label="CHURN\\n(60% probability)", fillcolor="#fee2e2", color="#ef4444"] | |
| retain2 [label="RETAIN\\n(90% probability)", fillcolor="#d1fae5", color="#10b981"] | |
| q1 -> q2 [label=" Yes"] | |
| q1 -> q3 [label=" No"] | |
| q2 -> churn1 [label=" Yes"] | |
| q2 -> retain1 [label=" No"] | |
| q3 -> churn2 [label=" Yes"] | |
| q3 -> retain2 [label=" No"] | |
| } | |
| """) | |
| st.markdown( | |
| """ | |
| **How it works:** | |
| 1. Imagine asking a series of yes/no questions: "Is the contract month-to-month?" | |
| → "Is tenure less than 12 months?" → "Are charges above $70?" | |
| That's a **decision tree** — it keeps splitting until it reaches an answer. | |
| 2. A single tree is easy to overfit (it memorizes the training data too closely). | |
| So Random Forest builds **200 trees**, each trained on a **random sample** | |
| of the data and a **random subset** of features. | |
| 3. For a new customer, all 200 trees make their prediction and we take | |
| the **majority vote**. | |
| **Why "Random"?** Each tree only sees a random portion of the data and features. | |
| This diversity prevents the forest from over-relying on any single pattern. | |
| **Why Label Encoding is fine:** Trees split on thresholds (e.g., "is Contract < 1?"). | |
| They never multiply the encoded number β so the integer values don't introduce false relationships. | |
| **Strengths:** Handles non-linear patterns, resistant to overfitting, works with Label Encoding. | |
| **Weaknesses:** Slower than Logistic Regression, less interpretable (200 trees are hard to inspect by hand). | |
| """ | |
| ) | |
| with st.expander("Step-by-step numerical walkthrough β Random Forest", expanded=False): | |
| st.markdown( | |
| """ | |
| **Same customer:** tenure=3, MonthlyCharges=$85, Contract=Month-to-month, | |
| Internet=Fiber optic, Partner=No. | |
| --- | |
| **Step 1 β Each tree sees different data** | |
| Random Forest creates 200 trees. Each tree gets: | |
| - A **bootstrap sample** — training customers drawn at random **with replacement** (a full-size draw contains only ~63% of the distinct customers; the rest are repeats) | |
| - A **random subset of features** at each split (e.g., 4 out of 19 features) | |
| | Tree | Training Customers | Features Available | | |
| |---|---|---| | |
| | Tree 1 | Customers #12, #45, #45, #78, #102, ... | tenure, Contract, Partner, TechSupport | | |
| | Tree 2 | Customers #3, #22, #56, #56, #89, ... | MonthlyCharges, InternetService, gender, SeniorCitizen | | |
| | Tree 3 | Customers #7, #33, #41, #67, #67, ... | tenure, PaymentMethod, MonthlyCharges, OnlineSecurity | | |
| | ... | ... | ... | | |
| Notice: Customer #45 appears twice in Tree 1 (bootstrap sampling), and each | |
| tree considers different features. | |
| --- | |
| **Step 2 β Each tree grows by finding the best splits** | |
| **Tree 1** (has tenure, Contract, Partner, TechSupport): | |
| The algorithm tries every possible split and picks the one that best | |
| separates churners from non-churners (measured by **Gini impurity**): | |
| """ | |
| ) | |
| st.latex(r"\text{Gini} = 1 - p_{\text{churn}}^2 - p_{\text{retain}}^2") | |
| st.markdown( | |
| """ | |
| | Candidate Split | Left Group (Churn%) | Right Group (Churn%) | Gini Improvement | | |
| |---|---|---|---| | |
| | Contract < 1 (Month-to-month) | 42% churn | 12% churn | **0.18** ← best | | |
| | tenure < 12 | 38% churn | 20% churn | 0.14 | | |
| | Partner = 0 | 30% churn | 24% churn | 0.03 | | |
| Contract split wins β becomes the first question. | |
| Then each branch splits again on the remaining features, creating deeper | |
| questions. This continues until leaves are pure or a depth limit is reached. | |
| --- | |
| **Step 3 β Each tree predicts independently** | |
| For our customer (Month-to-month, tenure=3, Fiber optic, no Partner): | |
| | Tree | Path through the tree | Prediction | | |
| |---|---|---| | |
| | Tree 1 | Contract=M-t-m → tenure<12 → **CHURN** | Churn (92%) | | |
| | Tree 2 | Charges>$70 → Fiber optic → **CHURN** | Churn (78%) | | |
| | Tree 3 | tenure<6 → Charges>$60 → **CHURN** | Churn (85%) | | |
| | Tree 4 | Contract=M-t-m → No TechSupport → **CHURN** | Churn (88%) | | |
| | Tree 5 | Fiber optic → tenure<24 → **RETAIN** | Retain (55%) | | |
| | ... | ... | ... | | |
| --- | |
| **Step 4 β Majority vote** | |
| Out of 200 trees: | |
| - 172 trees predict **Churn** | |
| - 28 trees predict **Retain** | |
| `Probability = 172/200 = 86%` → Predict **Churn** | |
| Note: Tree 5 predicted Retain — that's fine. The diversity is intentional. | |
| If every tree agreed perfectly, there would be no benefit from having 200 of them. | |
| --- | |
| **Why Random Forest handles non-linear patterns:** | |
| A single tree can capture "Contract=Month-to-month AND tenure<12 β high churn" | |
| without needing to encode this interaction explicitly. | |
| The forest combines 200 different views of the data, each capturing | |
| different interaction patterns. The majority vote smooths out individual | |
| tree errors β this is why forests rarely overfit. | |
| """ | |
| ) | |
| # ── XGBoost ────────────────────────────────────────────────────────────── | |
| st.markdown("---") | |
| st.markdown("### 4. XGBoost (Extreme Gradient Boosting)") | |
| st.markdown("*The algorithm that wins most Kaggle competitions.*") | |
| st.graphviz_chart(""" | |
| digraph xgb { | |
| rankdir=TB | |
| node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"] | |
| edge [color="#888888", penwidth=1.5] | |
| data [label="Training Data", fillcolor="#dbeafe", color="#3b82f6"] | |
| t1 [label="Tree 1\\nLearns the main pattern\\n(e.g., contract type)", fillcolor="#d1fae5", color="#10b981"] | |
| e1 [label="Errors from Tree 1\\n(customers it got wrong)", fillcolor="#fee2e2", color="#ef4444"] | |
| t2 [label="Tree 2\\nFocuses on Tree 1's mistakes", fillcolor="#d1fae5", color="#10b981"] | |
| e2 [label="Remaining errors", fillcolor="#fee2e2", color="#ef4444"] | |
| t3 [label="Tree 3\\nFocuses on remaining mistakes", fillcolor="#d1fae5", color="#10b981"] | |
| dots [label="... (200 trees total, each fixing\\nthe previous tree's errors)", shape=plaintext, fontname="Helvetica"] | |
| final [label="Final Prediction\\nSum of all 200 trees\\n(each weighted by learning rate)", fillcolor="#fef3c7", color="#f59e0b"] | |
| data -> t1 -> e1 -> t2 -> e2 -> t3 -> dots -> final | |
| } | |
| """) | |
| st.markdown( | |
| """ | |
| **How it works:** | |
| 1. **Tree 1** tries to predict churn for all customers. It gets some right, some wrong. | |
| 2. **Tree 2** doesn't start from scratch β it specifically focuses on the customers | |
| that Tree 1 got **wrong**. It tries to correct those mistakes. | |
| 3. **Tree 3** focuses on the remaining mistakes from Tree 1 + Tree 2 combined. | |
| 4. This continues for 200 rounds. Each new tree is a specialist in fixing | |
| what all previous trees couldn't get right. | |
| 5. The final prediction is the **weighted sum** of all 200 trees. | |
| **The key difference from Random Forest:** | |
| - Random Forest: 200 trees trained **independently** (in parallel), then vote. | |
| - XGBoost: 200 trees trained **sequentially**, each one learning from the previous one's errors. | |
| **Why "Gradient Boosting"?** "Gradient" refers to the mathematical technique used to | |
| determine *how* each new tree should focus on errors. It's the same gradient descent | |
| concept used in deep learning. | |
| **Strengths:** Usually the most accurate model, handles complex non-linear patterns, | |
| built-in regularization prevents overfitting. | |
| **Weaknesses:** Slower to train, harder to interpret, more hyperparameters to tune. | |
| """ | |
| ) | |
| with st.expander("Step-by-step numerical walkthrough β XGBoost", expanded=False): | |
| st.markdown( | |
| """ | |
| **Same customer:** tenure=3, MonthlyCharges=$85, Contract=Month-to-month, | |
| Internet=Fiber optic, Partner=No. | |
| --- | |
| **Step 1 β Start with a base prediction** | |
| Before any trees, XGBoost starts with the global average: | |
| `base prediction = overall churn rate = 0.265` (26.5%) | |
| Convert to log-odds: | |
| """ | |
| ) | |
| st.latex(r"\text{log-odds} = \ln\left(\frac{0.265}{1 - 0.265}\right) = \ln(0.361) = -1.02") | |
| st.markdown( | |
| """ | |
| Every customer starts at -1.02 (26.5% churn probability). | |
| --- | |
| **Step 2 β Tree 1 learns from the errors** | |
| For each customer, calculate the **residual** (how wrong the base prediction is): | |
| | Customer | Actual | Current Prediction | Residual | | |
| |---|---|---|---| | |
| | #1 | Churned (1) | 0.265 | +0.735 (prediction too low) | | |
| | #2 | Retained (0) | 0.265 | -0.265 (prediction too high) | | |
| | #3 | Churned (1) | 0.265 | +0.735 | | |
| | Our customer | Churned (1) | 0.265 | +0.735 | | |
| Tree 1 tries to predict these residuals (not the original labels). | |
| Tree 1 output for our customer: **+0.45** (it learned a partial correction). | |
| --- | |
| **Step 3 β Update prediction with learning rate** | |
| XGBoost uses a **learning rate** (η = 0.1 in this walkthrough; XGBoost's own default is 0.3) to prevent overshooting: | |
| `new log-odds = -1.02 + (0.1 × 0.45) = -1.02 + 0.045 = -0.975` | |
| """ | |
| ) | |
| st.latex(r"P = \frac{1}{1 + e^{-(-0.975)}} = \frac{1}{1 + 2.65} = 0.274") | |
| st.markdown( | |
| """ | |
| Probability moved from 26.5% to 27.4%. A small step in the right direction. | |
| --- | |
| **Step 4 β Tree 2 focuses on remaining errors** | |
| New residuals (using updated predictions): | |
| | Customer | Actual | Updated Prediction | New Residual | | |
| |---|---|---|---| | |
| | Our customer | 1 | 0.274 | +0.726 (still too low) | | |
| Tree 2 learns these new residuals. Output for our customer: **+0.52** | |
| `log-odds = -0.975 + (0.1 × 0.52) = -0.975 + 0.052 = -0.923` | |
| `P = 28.4%` | |
| --- | |
| **Step 5 β Continue for 200 trees** | |
| | After Tree | Log-odds | Churn Probability | | |
| |---|---|---| | |
| | Base (no trees) | -1.02 | 26.5% | | |
| | Tree 1 | -0.975 | 27.4% | | |
| | Tree 2 | -0.923 | 28.4% | | |
| | Tree 10 | -0.42 | 39.7% | | |
| | Tree 50 | 0.85 | 70.1% | | |
| | Tree 100 | 1.72 | 84.8% | | |
| | Tree 200 | 2.95 | 95.0% | | |
| Each tree adds a small correction. After 200 trees, the probability | |
| climbed from 26.5% to 95.0%. | |
| --- | |
| **The final prediction formula:** | |
| """ | |
| ) | |
| st.latex(r"\text{prediction} = \text{base} + \eta \cdot f_1(x) + \eta \cdot f_2(x) + \ldots + \eta \cdot f_{200}(x)") | |
| st.markdown( | |
| """ | |
| Where each f(x) is a small decision tree focused on the remaining errors. | |
| --- | |
| **Key differences from Random Forest:** | |
| | | Random Forest | XGBoost | | |
| |---|---|---| | |
| | **Trees learn** | Independently (in parallel) | Sequentially (each from previous errors) | | |
| | **Each tree predicts** | The original label (Churn/Retain) | The **residual error** from all previous trees | | |
| | **Learning rate** | Not applicable | Controls step size (η=0.1 is typical) | | |
| | **Why it helps** | Diversity from random sampling | Each tree is a specialist in remaining errors | | |
| | **Risk** | Hard to overfit | Can overfit if too many trees or learning rate too high | | |
| The learning rate is crucial: without it (η=1.0), early trees would over-correct | |
| and the model would overfit. Small steps (η=0.1) force the model to build | |
| gradually, which produces better generalization. | |
| """ | |
| ) | |
| # ── SHAP ───────────────────────────────────────────────────────────────── | |
| st.markdown("---") | |
| st.markdown("### 5. SHAP — How We Explain Predictions") | |
| st.markdown("*Making the black box transparent.*") | |
| st.graphviz_chart(""" | |
| digraph shap { | |
| rankdir=LR | |
| node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"] | |
| edge [color="#888888", penwidth=1.5] | |
| base [label="Base Rate\\n26.5% churn\\n(average customer)", fillcolor="#e0e7ff", color="#6366f1"] | |
| f1 [label="Contract =\\nMonth-to-month\\n+18%", fillcolor="#fee2e2", color="#ef4444"] | |
| f2 [label="tenure =\\n2 months\\n+12%", fillcolor="#fee2e2", color="#ef4444"] | |
| f3 [label="TechSupport =\\nNo\\n+5%", fillcolor="#fee2e2", color="#ef4444"] | |
| f4 [label="TotalCharges =\\n$150 (low)\\n+3%", fillcolor="#fee2e2", color="#ef4444"] | |
| f5 [label="Partner = Yes\\n-4%", fillcolor="#d1fae5", color="#10b981"] | |
| pred [label="Final Prediction\\n60.5% churn", fillcolor="#fef3c7", color="#f59e0b"] | |
| base -> f1 -> f2 -> f3 -> f4 -> f5 -> pred | |
| } | |
| """) | |
| st.markdown( | |
| """ | |
| **How SHAP works:** | |
| Every prediction starts from the **base rate** β the overall churn rate in the data (~26.5%). | |
| Then SHAP shows how each feature **pushes** that prediction up or down: | |
| - Month-to-month contract → pushes probability **up** (toward churn) | |
| - Low tenure → pushes probability **up** (new customer, high risk) | |
| - Having a partner → pushes probability **down** (slightly protective) | |
| - Each feature gets a + or - contribution, and they all add up to the final prediction. | |
| **Why this matters for business:** If the model predicts an 80% churn probability, | |
| SHAP tells you *why* β "It's mainly because they're on a month-to-month contract | |
| and only been with us for 2 months." That's actionable: offer them a yearly contract | |
| with a discount. | |
| **The name "SHAP"** stands for SHapley Additive exPlanations, based on Shapley values | |
| from game theory β a mathematically rigorous way to fairly distribute credit among features. | |
| """ | |
| ) | |
| with st.expander("Step-by-step numerical walkthrough β SHAP values", expanded=False): | |
| st.markdown( | |
| """ | |
| SHAP uses a concept from game theory: **Shapley values**. The idea is to | |
| fairly distribute credit among "players" (features) for the "payout" (prediction). | |
| **The question:** XGBoost predicts 95% churn for our customer. But *which features | |
| caused this?* And by how much? | |
| --- | |
| **Step 1 β Start from the base value** | |
| The base value is the average model output across all training customers: | |
| `base value = 26.5%` (overall churn rate) | |
| --- | |
| **Step 2 β Measure each feature's marginal contribution** | |
| For each feature, SHAP asks: "If I remove this feature, how much does the | |
| prediction change?" But it's more sophisticated β it checks this across | |
| **every possible combination** of other features. | |
| **Example for "Contract = Month-to-month":** | |
| | Features included (besides Contract) | Prediction with Contract | Prediction without | Contribution of Contract | | |
| |---|---|---|---| | |
| | None | 45% | 26.5% | +18.5% | | |
| | tenure only | 52% | 35% | +17% | | |
| | tenure + MonthlyCharges | 68% | 48% | +20% | | |
| | tenure + MonthlyCharges + Internet | 82% | 60% | +22% | | |
| | All features | 95% | 73% | +22% | | |
| The SHAP value for Contract = **average of all marginal contributions** | |
| = (18.5 + 17 + 20 + 22 + 22) / 5 = **+19.9%** | |
| --- | |
| **Step 3 β Do this for every feature** | |
| | Feature | SHAP Value | Direction | | |
| |---|---|---| | |
| | Contract = Month-to-month | +19.9% | Pushes toward churn | | |
| | tenure = 3 months | +15.2% | Pushes toward churn | | |
| | InternetService = Fiber optic | +11.8% | Pushes toward churn | | |
| | MonthlyCharges = $85 | +8.4% | Pushes toward churn | | |
| | TechSupport = No | +6.1% | Pushes toward churn | | |
| | OnlineSecurity = No | +4.2% | Pushes toward churn | | |
| | Partner = No | +2.1% | Pushes toward churn | | |
| | PaymentMethod = Electronic check | +1.3% | Pushes toward churn | | |
| | Other features | -0.5% | Small protective effects | | |
| --- | |
| **Step 4 β Everything adds up to the prediction** | |
| `26.5% (base) + 19.9% + 15.2% + 11.8% + 8.4% + 6.1% + 4.2% + 2.1% + 1.3% - 0.5% = 95.0%` | |
| This is guaranteed by the Shapley value properties β the contributions | |
| always sum exactly to the difference between the base value and the prediction. | |
| --- | |
| **Why Shapley values are "fair":** | |
| Shapley values are the *only* method that satisfies all four fairness properties: | |
| 1. **Efficiency** β contributions sum to the total prediction | |
| 2. **Symmetry** β features with identical effects get identical values | |
| 3. **Null player** β irrelevant features get zero contribution | |
| 4. **Linearity** β the method is consistent across combined models | |
| This is why SHAP is preferred over simpler methods like "feature importance" β | |
| it's mathematically guaranteed to distribute credit fairly. | |
| """ | |
| ) | |
| # ── Comparison Table ───────────────────────────────────────────────────── | |
| st.markdown("---") | |
| st.markdown("### Quick Comparison") | |
| st.markdown( | |
| """ | |
| | | Logistic Regression | Random Forest | XGBoost | | |
| |---|---|---|---| | |
| | **How it learns** | Finds the best weights for a linear equation | Builds many independent trees and averages their votes | Builds trees sequentially, each correcting the last | | |
| | **Encoding** | One-Hot (needs binary columns) | Label (integers are fine) | Label (integers are fine) | | |
| | **Speed** | Very fast | Moderate | Slower | | |
| | **Accuracy** | Good baseline | Very good | Usually best | | |
| | **Interpretability** | High (weights = feature importance) | Medium (feature importance available) | Medium (needs SHAP for full explanation) | | |
| | **Best for** | Simple, linear relationships | Non-linear patterns with moderate data | Complex patterns, competitions, production systems | | |
| """ | |
| ) | |
# ── Why Each Model ───────────────────────────────────────────────────────────
# Static explainer: why each model (and SHAP) is part of the app, how they
# complement each other, and why Logistic Regression can keep up on this data.
# The text had mis-encoded punctuation (em/en dashes, multiplication sign) and
# no blank lines between markdown blocks; both are restored here so headings,
# tables, rules and paragraphs render as separate elements.
# NOTE(review): indentation was lost in this copy of the file; re-indent under
# the enclosing `with` block if this section belongs to a tab — confirm.
st.markdown("---")
st.markdown("### Why We Chose Each Model")
st.markdown(
    """
#### Logistic Regression — The Baseline

| | |
|---|---|
| **Why we included it** | Every ML project needs a simple baseline to compare against. If a complex model can't beat Logistic Regression, the complexity isn't justified. |
| **Where it's used in industry** | Credit scoring (banks are required to use interpretable models), medical diagnosis screening, spam detection, ad click prediction at scale. |
| **Key benefit** | Full transparency — every feature gets a weight you can show to stakeholders. "Month-to-month contract increases churn probability by X%." |
| **Limitation** | Assumes each feature contributes independently. Can't learn "fiber optic is only risky when tenure is short" without manually creating that interaction feature. |
| **When to pick this** | When you need to explain *why* to regulators or non-technical stakeholders, or when you need a fast model that runs on millions of rows per second. |

---

#### Naive Bayes — The Probabilistic Baseline

| | |
|---|---|
| **Why we included it** | Provides a fundamentally different approach based on probability theory. Makes the independence assumption explicit, which helps identify when features actually interact. |
| **Where it's used in industry** | Email spam filtering (Gmail), document classification, real-time sentiment analysis, medical diagnosis screening. |
| **Key benefit** | Extremely fast training and prediction. Works well with small datasets. Naturally outputs calibrated probabilities. |
| **Limitation** | The independence assumption is often violated (e.g., contract type and monthly charges are correlated). Can't capture feature interactions. |
| **When to pick this** | When speed matters more than accuracy, or as a baseline to test if feature interactions are important (if NB matches complex models, interactions don't matter). |

---

#### Random Forest — The Reliable Workhorse

| | |
|---|---|
| **Why we included it** | It captures non-linear patterns and feature interactions that Logistic Regression and Naive Bayes miss, without needing manual feature engineering. |
| **Where it's used in industry** | Fraud detection (PayPal), customer segmentation, manufacturing defect prediction, insurance risk assessment. |
| **Key benefit** | Robust to outliers and noise. Rarely overfits. Handles mixed feature types without scaling or encoding concerns. |
| **Limitation** | Slower inference than Logistic Regression or Naive Bayes. The 200-tree ensemble is harder to explain to non-technical audiences. |
| **When to pick this** | When you need reliable accuracy with minimal tuning. It's the "safe choice" — almost always performs well without surprises. |

---

#### XGBoost — The Performance Leader

| | |
|---|---|
| **Why we included it** | It's consistently the top-performing algorithm for structured/tabular data like our customer dataset. |
| **Where it's used in industry** | Won most Kaggle competitions (2015–present). Used at Airbnb (search ranking), Uber (ETA prediction), major banks (credit risk). |
| **Key benefit** | Sequential error correction means it learns from its own mistakes. Built-in regularization prevents overfitting. |
| **Limitation** | More hyperparameters to tune. Slower to train than Random Forest. Requires SHAP for interpretability. |
| **When to pick this** | When predictive accuracy is the priority and you have time to tune. For production churn models, XGBoost is typically the default choice. |

---

#### SHAP — The Explainability Layer

| | |
|---|---|
| **Why we included it** | A prediction without an explanation has no business value. SHAP makes any model interpretable. |
| **Where it's used in industry** | Required by EU regulations (GDPR "right to explanation"). Used at all major banks for credit decisions. Standard at tech companies for model debugging. |
| **Key benefit** | Works with *any* model. Gives both global importance (which features matter overall) and local explanations (why *this* customer was flagged). |
| **When to use** | Always. SHAP should be part of every production ML system. |

---

#### How They Complement Each Other

| Role | Model | Why |
|---|---|---|
| **Linear baseline** | Logistic Regression | Sets the floor — any model must beat this to justify its complexity |
| **Probabilistic baseline** | Naive Bayes | Tests if feature independence assumption holds. If NB performs poorly, we know interactions matter. |
| **Production model** | XGBoost | Highest accuracy for deployment |
| **Backup / ensemble** | Random Forest | If XGBoost overfits on new data, RF is a stable alternative |
| **Explainability** | SHAP on XGBoost | Turns the best model into something stakeholders can act on |
| **Live updates** | SGDClassifier (Live page) | The only model supporting `partial_fit` for incremental learning |

In a real deployment, you would use XGBoost as the primary model with SHAP
for explanations, keep Random Forest as a monitoring baseline, compare against
Naive Bayes to validate that feature interactions are being captured, and use the
SGDClassifier pattern for real-time drift adaptation.

---

#### Why Logistic Regression Can Match or Beat Complex Models on This Dataset

If Logistic Regression performs as well as XGBoost here, that's an important
finding — not a flaw.

| Factor | Why It Helps LR |
|---|---|
| **Small dataset (~5,600 training rows)** | XGBoost has far more parameters (200 trees × many splits) than LR (~20 weights). With limited data, complex models risk overfitting. LR's simplicity is an advantage. |
| **Linear churn signal** | Contract type, tenure, and charges have straightforward relationships with churn. These are essentially linear effects — exactly what LR is designed for. |
| **One-hot encoding** | LR receives one-hot encoded features (each category as its own binary column). This avoids false ordinal relationships that label encoding introduces for linear models. |
| **Strong regularization** | LR's built-in L2 regularization prevents any single feature from dominating, keeping the model stable. |

**What this means in practice:** A simpler, fully interpretable model achieving
top accuracy is a **win**. Every prediction can be explained to the retention team,
the model is faster to deploy, and regulatory compliance is easier.

**When would XGBoost pull ahead?** With a much larger dataset (100K+ customers)
and richer features (browsing behavior, support call logs, time-series usage patterns),
XGBoost would exploit non-linear interactions that LR can't capture — like
"fiber optic is only risky for short-tenure customers without tech support."
"""
)
# ── Tab 0: Predict and Compare ───────────────────────────────────────────────
# Shows a small fixed sample of test-set customers and, when the button is
# pressed, every trained model's prediction for each of them next to the
# actual outcome, followed by a summary of how often the models agreed.
with tab_predict:
    st.subheader("Predict and Compare — Live Demo")

    # Demo sample: the first three churned and first two retained test rows.
    churned_idx = y_test[y_test == 1].index[:3]
    retained_idx = y_test[y_test == 0].index[:2]
    demo_idx = churned_idx.tolist() + retained_idx.tolist()

    # The count is derived from the sample so the text stays correct even if
    # the test split holds fewer matching rows than requested.
    st.markdown(
        f"Below are **{len(demo_idx)} real customers** from the test set (data the models have never seen during training). "
        "Click **Run Predictions** to see what each model thinks — then compare against what actually happened."
    )

    raw_df_pred = load_raw_data()
    display_columns = [
        "customerID", "gender", "SeniorCitizen", "tenure", "Contract",
        "InternetService", "MonthlyCharges", "TotalCharges",
    ]
    # NOTE(review): assumes y_test carries the raw frame's index labels —
    # confirm against get_train_test().
    demo_display = raw_df_pred.loc[demo_idx, display_columns].copy()
    demo_display.index = range(1, len(demo_display) + 1)
    demo_display.index.name = "#"

    st.markdown("#### Customer Profiles")
    st.dataframe(demo_display, use_container_width=True)

    if st.button("Run Predictions", type="primary", use_container_width=True):
        st.markdown("---")
        st.markdown("#### Prediction Results")

        model_names = list(all_models.keys())
        # One layout shared by the header and every result row:
        # "#", customer ID, one column per model, actual outcome, status icon.
        col_widths = [0.5, 1.5] + [2] * len(model_names) + [1.2, 0.5]
        actual_pos = len(model_names) + 2
        status_pos = len(model_names) + 3

        # NOTE(review): the original glyphs in this section were lost to a
        # mis-encoding; the emoji/check marks below are the assumed originals.
        status_icons = {"all_correct": "✅", "all_wrong": "❌", "mixed": "⚠️"}

        header_cols = st.columns(col_widths)
        header_cols[0].markdown("**#**")
        header_cols[1].markdown("**Customer ID**")
        for j, mn in enumerate(model_names):
            header_cols[j + 2].markdown(f"**{mn}**")
        header_cols[actual_pos].markdown("**Actual**")
        header_cols[status_pos].markdown("**✓**")
        st.markdown("---")

        # Collect one result record per demo customer.
        results_rows = []
        for pos, idx in enumerate(demo_idx, 1):
            actual_val = y_test.loc[idx]
            row_result = {
                "#": pos,
                "Customer ID": raw_df_pred.loc[idx, "customerID"],
            }
            correct_flags = []
            for model_name in model_names:
                model_obj = all_models[model_name]
                # Each model is scored on its own encoding of the same row.
                row_enc = model_test_data[model_name].loc[[idx]]
                pred = int(model_obj.predict(row_enc)[0])
                proba = float(model_obj.predict_proba(row_enc)[0][1])
                pred_label = "Churned" if pred == 1 else "Retained"
                correct = bool(pred == actual_val)
                correct_flags.append(correct)
                row_result[model_name] = f"{pred_label} ({proba:.0%})"
                row_result[f"{model_name}_correct"] = correct
            row_result["Actual"] = "Churned" if actual_val == 1 else "Retained"
            if all(correct_flags):
                row_result["_status"] = "all_correct"
            elif not any(correct_flags):
                row_result["_status"] = "all_wrong"
            else:
                row_result["_status"] = "mixed"
            results_rows.append(row_result)

        # Render the records directly; no intermediate DataFrame is needed.
        for row in results_rows:
            cols = st.columns(col_widths)
            cols[0].markdown(f"**{row['#']}**")
            cols[1].markdown(f"`{row['Customer ID']}`")
            for j, mn in enumerate(model_names):
                correct = row[f"{mn}_correct"]
                mark = "✓" if correct else "✗"
                cols[j + 2].markdown(
                    f"{'🟢' if correct else '🔴'} {row[mn]} {mark}"
                )
            cols[actual_pos].markdown(f"**{row['Actual']}**")
            cols[status_pos].markdown(status_icons[row["_status"]])

        # NOTE(review): indentation was lost in this copy; the divider is
        # assumed to follow the whole table rather than each row — confirm.
        st.markdown("---")

        n_total = len(results_rows)
        n_all_correct = sum(1 for r in results_rows if r["_status"] == "all_correct")
        n_all_wrong = sum(1 for r in results_rows if r["_status"] == "all_wrong")
        n_mixed = n_total - n_all_correct - n_all_wrong

        rc1, rc2, rc3 = st.columns(3)
        rc1.metric("All Models Correct", f"{n_all_correct} / {n_total}", delta="✅")
        rc2.metric("Mixed Results", f"{n_mixed} / {n_total}", delta="⚠️" if n_mixed > 0 else None)
        rc3.metric(
            "All Models Wrong",
            f"{n_all_wrong} / {n_total}",
            delta="❌" if n_all_wrong > 0 else None,
            delta_color="inverse",
        )

        st.markdown(
            "**Legend:** 🟢 = correct prediction, 🔴 = wrong prediction. "
            "Percentage shown is the model's estimated churn probability."
        )
# ── Tab 1: Model Comparison ──────────────────────────────────────────────────
# Metrics table (with the encoding each model was trained on), ROC curves and
# one confusion matrix per trained model.
with tab_compare:
    st.subheader("Performance Metrics")
    metrics_df = pd.DataFrame(metrics).T
    metrics_df = metrics_df.round(3)

    best_model = metrics_df["AUC"].idxmax()
    st.info(f"Best model by AUC: **{best_model}** ({metrics_df.loc[best_model, 'AUC']:.3f})")

    # Naive Bayes is trained on the label-encoded X_train, so it needs an entry
    # here too; without one its Encoding cell rendered as NaN.
    encoding_by_model = {
        "Logistic Regression": "One-Hot",
        "Naive Bayes": "Label",
        "Random Forest": "Label",
        "XGBoost": "Label",
    }
    # Align to the metrics index so the table keeps the metrics' row order and
    # gains no extra rows for models that were not evaluated.
    encoding_col = pd.Series(encoding_by_model, name="Encoding").reindex(metrics_df.index)
    display_metrics = pd.concat([encoding_col, metrics_df], axis=1)
    st.dataframe(
        display_metrics.style.highlight_max(axis=0, subset=metrics_df.columns, color="#c6efce"),
        use_container_width=True,
    )

    st.markdown("---")
    st.subheader("ROC Curves")
    roc_entries = [(name, all_models[name], model_test_data[name]) for name in all_models]
    st.plotly_chart(plot_roc_curves(roc_entries, y_test), use_container_width=True)

    st.markdown("---")
    st.subheader("Confusion Matrices")
    # One column per model instead of a hard-coded 4, so adding a model cannot
    # index past the end of the column list.
    cm_cols = st.columns(len(all_models))
    for cm_col, name in zip(cm_cols, all_models):
        with cm_col:
            y_pred = all_models[name].predict(model_test_data[name])
            st.plotly_chart(
                plot_confusion_matrix(y_test, y_pred, title=name),
                use_container_width=True,
            )
# ── Tab 2: Global SHAP ───────────────────────────────────────────────────────
# Global SHAP views (bar and beeswarm) for the XGBoost model. `xgb_model` and
# `explainer` are left at script level because the later tabs reuse them.
with tab_shap_global:
    st.subheader("Global Feature Importance — XGBoost")
    st.markdown("SHAP values show how much each feature pushes the prediction toward or away from churn.")

    xgb_model = all_models["XGBoost"]
    explainer, shap_values = get_shap_explainer(xgb_model, X_train)

    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    shap.plots.bar(shap_values, max_display=15, show=False, ax=ax_bar)
    st.pyplot(fig_bar)
    # pyplot keeps every figure alive until it is closed; this script re-runs
    # on each interaction, so close figures once they have been rendered.
    plt.close(fig_bar)

    st.markdown("---")
    st.markdown("**Beeswarm Plot** — Each dot is a customer. Color = feature value (red = high, blue = low).")
    fig_bee, ax_bee = plt.subplots(figsize=(10, 8))
    shap.plots.beeswarm(shap_values, max_display=15, show=False)
    # Render whatever figure SHAP drew on (the current one), then close both it
    # and the figure created above; closing an already-closed figure is a no-op.
    drawn_fig = plt.gcf()
    st.pyplot(drawn_fig)
    plt.close(drawn_fig)
    plt.close(fig_bee)
# ── Tab 3: Individual Explanations ───────────────────────────────────────────
# Picks one customer, shows the XGBoost churn probability next to the actual
# outcome, and explains the prediction with a SHAP waterfall plot.
with tab_shap_individual:
    st.subheader("Explain a Single Customer's Prediction")
    raw_df = load_raw_data()
    customer_ids = raw_df["customerID"].values
    # Only the first 200 IDs are offered in the selector.
    selected_id = st.selectbox("Select Customer ID", customer_ids[:200])

    # Filter once and reuse the result for both the index lookup and the
    # details table below.
    customer_row = raw_df[raw_df["customerID"] == selected_id]
    idx_in_raw = customer_row.index[0]
    enc_df, _ = get_encoded_data()

    # Take the encoded features from whichever split holds this customer, and
    # fall back to the full encoded frame if neither does.
    if idx_in_raw in X_test.index:
        row = X_test.loc[[idx_in_raw]]
        actual = y_test.loc[idx_in_raw]
    elif idx_in_raw in X_train.index:
        row = X_train.loc[[idx_in_raw]]
        actual = y_train.loc[idx_in_raw]
    else:
        row = enc_df.loc[[idx_in_raw], feature_cols]
        actual = enc_df.loc[idx_in_raw, "Churn"]

    proba = xgb_model.predict_proba(row)[0][1]
    c1, c2 = st.columns(2)
    c1.metric("Predicted Churn Probability", f"{proba:.1%}")
    c2.metric("Actual Outcome", "Churned" if actual == 1 else "Retained")

    st.markdown("**Customer Details (raw values):**")
    st.dataframe(customer_row.T, use_container_width=True)

    st.markdown("---")
    st.markdown("**SHAP Waterfall — Why this prediction?**")
    sv = get_shap_single(explainer, row)
    fig_wf, ax_wf = plt.subplots(figsize=(10, 6))
    shap.plots.waterfall(sv[0], max_display=12, show=False)
    # Render the figure SHAP drew on, then close it: pyplot keeps figures alive
    # until closed and this script re-runs on every interaction.
    drawn_fig = plt.gcf()
    st.pyplot(drawn_fig)
    plt.close(drawn_fig)
    plt.close(fig_wf)
# ── Tab 4: What-If Predictor ─────────────────────────────────────────────────
# Lets the user assemble a hypothetical customer from widgets, label-encodes
# it, scores it with the XGBoost model and explains the score with a SHAP
# waterfall plot. Relies on `xgb_model` and `explainer` from the Global SHAP tab.
with tab_whatif:
    st.subheader("What-If Predictor")
    st.markdown("Adjust customer features and see how churn probability changes in real time.")

    wi_col1, wi_col2, wi_col3 = st.columns(3)
    with wi_col1:
        wi_gender = st.selectbox("Gender", ["Female", "Male"], key="wi_gender")
        wi_senior = st.selectbox("Senior Citizen", [0, 1], key="wi_senior")
        wi_partner = st.selectbox("Partner", ["Yes", "No"], key="wi_partner")
        wi_dependents = st.selectbox("Dependents", ["Yes", "No"], key="wi_dep")
        wi_tenure = st.slider("Tenure (months)", 0, 72, 12, key="wi_tenure")
    with wi_col2:
        wi_phone = st.selectbox("Phone Service", ["Yes", "No"], key="wi_phone")
        wi_multi = st.selectbox("Multiple Lines", ["No", "Yes", "No phone service"], key="wi_multi")
        wi_internet = st.selectbox("Internet Service", ["DSL", "Fiber optic", "No"], key="wi_inet")
        wi_security = st.selectbox("Online Security", ["Yes", "No", "No internet service"], key="wi_sec")
        wi_backup = st.selectbox("Online Backup", ["Yes", "No", "No internet service"], key="wi_bak")
    with wi_col3:
        wi_protection = st.selectbox("Device Protection", ["Yes", "No", "No internet service"], key="wi_prot")
        wi_support = st.selectbox("Tech Support", ["Yes", "No", "No internet service"], key="wi_sup")
        wi_tv = st.selectbox("Streaming TV", ["Yes", "No", "No internet service"], key="wi_tv")
        wi_movies = st.selectbox("Streaming Movies", ["Yes", "No", "No internet service"], key="wi_mov")
        wi_contract = st.selectbox("Contract", ["Month-to-month", "One year", "Two year"], key="wi_con")

    wi_col4, wi_col5 = st.columns(2)
    with wi_col4:
        wi_paperless = st.selectbox("Paperless Billing", ["Yes", "No"], key="wi_paper")
        wi_payment = st.selectbox(
            "Payment Method",
            ["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"],
            key="wi_pay",
        )
    with wi_col5:
        wi_monthly = st.slider("Monthly Charges ($)", 18.0, 120.0, 70.0, step=0.5, key="wi_monthly")
        wi_total = st.slider("Total Charges ($)", 18.0, 9000.0, 1500.0, step=10.0, key="wi_total")

    # Categorical inputs, keyed by feature column; these get label-encoded.
    input_dict = {
        "gender": wi_gender, "Partner": wi_partner, "Dependents": wi_dependents,
        "PhoneService": wi_phone, "MultipleLines": wi_multi, "InternetService": wi_internet,
        "OnlineSecurity": wi_security, "OnlineBackup": wi_backup,
        "DeviceProtection": wi_protection, "TechSupport": wi_support,
        "StreamingTV": wi_tv, "StreamingMovies": wi_movies, "Contract": wi_contract,
        "PaperlessBilling": wi_paperless, "PaymentMethod": wi_payment,
    }
    # Numeric inputs are passed to the model unchanged.
    numeric_dict = {
        "SeniorCitizen": wi_senior, "tenure": wi_tenure,
        "MonthlyCharges": wi_monthly, "TotalCharges": wi_total,
    }

    _, enc_map = get_encoded_data()
    encoded_input = {}
    unknown_values = []
    for col, val in input_dict.items():
        le = enc_map[col]
        if val in le.classes_:
            encoded_input[col] = le.transform([val])[0]
        else:
            # Keep the best-effort fallback to code 0, but record it so the
            # user is told the prediction rests on a substituted value.
            encoded_input[col] = 0
            unknown_values.append(f"{col}={val!r}")
    if unknown_values:
        st.warning(
            "Some selections are unknown to the fitted encoders and were encoded as 0: "
            + ", ".join(unknown_values)
        )
    encoded_input.update(numeric_dict)

    # Order the columns the way the model was trained.
    input_row = pd.DataFrame([encoded_input])[feature_cols]
    wi_proba = xgb_model.predict_proba(input_row)[0][1]

    st.markdown("---")
    res_col1, res_col2 = st.columns([1, 2])
    with res_col1:
        st.metric("Churn Probability", f"{wi_proba:.1%}")
        # Risk bands: below 30% low, below 60% medium, otherwise high.
        if wi_proba < 0.3:
            st.success("Low risk — customer likely to stay")
        elif wi_proba < 0.6:
            st.warning("Medium risk — consider retention offer")
        else:
            st.error("High risk — immediate intervention recommended")
    with res_col2:
        st.plotly_chart(plot_gauge(wi_proba), use_container_width=True)

    st.markdown("---")
    st.markdown("**Top Feature Drivers for This Configuration:**")
    sv_wi = get_shap_single(explainer, input_row)
    fig_wi, ax_wi = plt.subplots(figsize=(10, 5))
    shap.plots.waterfall(sv_wi[0], max_display=10, show=False)
    # Render the figure SHAP drew on, then close it: pyplot keeps figures alive
    # until closed and this script re-runs on every widget change.
    drawn_fig = plt.gcf()
    st.pyplot(drawn_fig)
    plt.close(drawn_fig)
    plt.close(fig_wi)