# ml-demo / pages/3_Churn_Models.py
# Author: aliarafat-stack-ml
# Commit: "first commit" (f3a6f24)
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import shap
import streamlit as st

from utils.data_loader import (
    load_raw_data, get_train_test, get_onehot_train_test,
    get_encoded_data, CATEGORICAL_COLS, NUMERIC_COLS,
)
from utils.models import (
    train_tree_models, train_lr_model, train_nb_model, evaluate_model,
    get_shap_explainer, get_shap_single,
)
from utils.visualizations import plot_roc_curves, plot_confusion_matrix, plot_gauge
# Page chrome. set_page_config must be the first Streamlit call on the page.
st.set_page_config(page_title="Churn Models", page_icon="🤖", layout="wide")
st.title("Churn Prediction Models")
st.markdown("---")
# Two views of the same train/test split:
# - label-encoded (X_train/X_test) feeds Naive Bayes and the tree models;
# - one-hot (X_train_oh/X_test_oh) feeds Logistic Regression.
X_train, X_test, y_train, y_test, encoders, feature_cols = get_train_test()
X_train_oh, X_test_oh, _, _, feature_cols_oh = get_onehot_train_test()
# On the first run of a session, create the slots that persist the trained
# models, their per-model test matrices and their metrics across reruns.
if "churn_models_trained" not in st.session_state:
    for state_key, initial_value in (
        ("churn_models_trained", False),
        ("all_models", None),
        ("model_test_data", None),
        ("metrics", None),
    ):
        st.session_state[state_key] = initial_value
# Tell the user up front why Logistic Regression is fed a different feature
# matrix (one-hot) than the other three models (label-encoded).
st.info(
    "**Encoding note:** Logistic Regression is trained on **One-Hot Encoded** features "
    f"({len(feature_cols_oh)} columns) while Naive Bayes, Random Forest and XGBoost use "
    f"**Label Encoding** ({len(feature_cols)} columns). "
    "One-hot encoding stops the linear model from reading a false order into "
    "categories; the tree models split on thresholds, so compact label codes suit them."
)
# Gate the rest of the page behind a one-off training step. Once training
# succeeds, the results are stored in session state and the script reruns.
if not st.session_state.churn_models_trained:
    st.warning("⚠️ Models not trained yet. Click the button below to train all models.")
    if st.button("🚀 Train All Models", type="primary", use_container_width=True):
        with st.spinner("Training models..."):
            tree_models = train_tree_models(X_train, y_train)
            lr_model = train_lr_model(X_train_oh, y_train)
            nb_model = train_nb_model(X_train, y_train)

            all_models = {
                "Logistic Regression": lr_model,
                "Naive Bayes": nb_model,
            }
            all_models.update(tree_models)

            # Each model must be scored on the same encoding it was trained on:
            # LR on one-hot features, everything else on label-encoded ones.
            # The tree entries are derived from tree_models' own keys, so this
            # mapping cannot drift out of sync with train_tree_models (a
            # hard-coded name mismatch would raise KeyError below).
            model_test_data = {
                "Logistic Regression": X_test_oh,
                "Naive Bayes": X_test,
            }
            model_test_data.update(dict.fromkeys(tree_models, X_test))

            metrics = {
                name: evaluate_model(model, model_test_data[name], y_test)
                for name, model in all_models.items()
            }

        st.session_state.all_models = all_models
        st.session_state.model_test_data = model_test_data
        st.session_state.metrics = metrics
        st.session_state.churn_models_trained = True
        # Rerun so the page renders from the "trained" branch below.
        st.rerun()
    # Nothing below this point can render without trained models.
    st.stop()
# Models are trained - pull the cached models, their per-model test matrices
# and their metrics back out of session state (no retraining on rerun).
all_models = st.session_state.all_models
model_test_data = st.session_state.model_test_data
metrics = st.session_state.metrics
st.success("✅ All models trained and ready!")

tab_how, tab_predict, tab_compare, tab_shap_global, tab_shap_individual, tab_whatif = st.tabs(
    [
        "How The Algorithms Work",
        "Predict and Compare",
        "Model Comparison",
        "Feature Importance (SHAP)",
        "Individual Explanations",
        "What-If Predictor",
    ]
)
# ── Tab: How The Algorithms Work ─────────────────────────────────────────────
# Educational tab: a visual and numerical walkthrough of each model.
with tab_how:
    st.subheader("How Each Algorithm Works — A Visual Guide")
    st.markdown(
        "We use four very different algorithms to predict churn. Each approaches "
        "the problem in its own way. Understanding the differences helps explain "
        "why one model might outperform another and when predictions diverge."
    )
# ── Logistic Regression ──────────────────────────────────────────────────
# Flow diagram, sigmoid plot, prose explanation, then a worked example.
st.markdown("---")
st.markdown("### 1. Logistic Regression")
st.markdown("*The simplest model — and often a strong baseline.*")
st.graphviz_chart("""
digraph lr {
    rankdir=LR
    node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.3,0.15"]
    edge [color="#888888", penwidth=1.5]
    features [label="Customer Features\\ntenure, charges,\\ncontract type, ...", fillcolor="#dbeafe", color="#3b82f6"]
    weighted [label="Weighted Sum\\nw₁×tenure + w₂×charges\\n+ w₃×contract + ...", fillcolor="#e0e7ff", color="#6366f1"]
    sigmoid [label="Sigmoid Function\\nSquash to 0-1", fillcolor="#fce7f3", color="#ec4899"]
    output [label="Probability\\n0.82 → Churn\\n0.15 → Retain", fillcolor="#d1fae5", color="#10b981"]
    features -> weighted -> sigmoid -> output
}
""")

# Plot the logistic (sigmoid) function over z in [-8, 8] to show how any
# linear score maps onto a probability between 0 and 1.
x_sig = np.linspace(-8, 8, 200)
y_sig = 1 / (1 + np.exp(-x_sig))
fig_sig = go.Figure()
fig_sig.add_trace(go.Scatter(x=x_sig, y=y_sig, mode="lines", line=dict(color="#6366f1", width=3), name="Sigmoid"))
fig_sig.add_hline(y=0.5, line_dash="dash", line_color="gray", annotation_text="Decision boundary (0.5)")
# z < 0 gives probability < 0.5 ("Retain"); z > 0 gives probability > 0.5 ("Churn").
fig_sig.add_vrect(x0=-8, x1=0, fillcolor="#d1fae5", opacity=0.15, annotation_text="Retain", annotation_position="bottom left")
fig_sig.add_vrect(x0=0, x1=8, fillcolor="#fee2e2", opacity=0.15, annotation_text="Churn", annotation_position="bottom right")
fig_sig.update_layout(
    title="The Sigmoid Curve — Turns Any Number into a Probability",
    xaxis_title="Weighted Sum of Features (z = w₁x₁ + w₂x₂ + ...)",
    yaxis_title="Churn Probability",
    height=350,
    yaxis=dict(range=[0, 1]),
)
st.plotly_chart(fig_sig, use_container_width=True)

st.markdown(
    """
**How it works:**

1. Each customer feature gets a **weight** (a number the model learns). For example,
   "month-to-month contract" might get a high positive weight (pushes toward churn),
   while "long tenure" gets a negative weight (pushes away from churn).
2. The model adds up: `weight₁ × feature₁ + weight₂ × feature₂ + ...`
3. This sum could be any number (-∞ to +∞). The **sigmoid function** (the S-curve above)
   squashes it into a probability between 0 and 1.
4. If the probability > 0.5 → predict "Churn". Otherwise → "Retain".

**Why we use One-Hot Encoding for this model:** Because the weighted sum treats numbers
at face value. If we encoded "Month-to-month"=2 and "Two year"=0, the model would
think month-to-month is "more" of something than two-year — which is nonsensical.
One-Hot avoids this by giving each category its own yes/no column.

**Strengths:** Fast, interpretable (weights directly tell you what matters), good baseline.

**Weaknesses:** Assumes a linear relationship between features and the log-odds of churn.
Can't capture complex interactions (e.g., "fiber optic is only risky for short-tenure customers").
"""
)

with st.expander("Step-by-step numerical walkthrough — Logistic Regression", expanded=False):
    st.markdown(
        """
Let's trace through exactly how the model makes one prediction.

**A single customer:**

| Feature | Value |
|---|---|
| tenure | 3 months |
| MonthlyCharges | $85 |
| Contract_Month-to-month | 1 (one-hot) |
| Contract_One year | 0 |
| Contract_Two year | 0 |
| InternetService_Fiber optic | 1 |
| Partner | 0 (no) |

(After one-hot encoding and scaling)

---

**Step 1 — The model has learned weights (one per feature)**

| Feature | Learned Weight | Meaning |
|---|---|---|
| tenure (scaled) | -1.2 | Longer tenure → less churn |
| MonthlyCharges (scaled) | +0.5 | Higher charges → more churn |
| Contract_Month-to-month | +1.4 | Month-to-month → high churn risk |
| Contract_Two year | -0.9 | Two-year contract → protective |
| InternetService_Fiber optic | +0.8 | Fiber optic → correlated with churn |
| Partner | -0.3 | Having a partner → slightly protective |
| (bias term) | -0.2 | Baseline offset |

---

**Step 2 — Compute the weighted sum (z)**
"""
    )
    st.latex(r"z = b + w_1 x_1 + w_2 x_2 + w_3 x_3 + \ldots")
    st.markdown(
        """
For our customer (tenure scales to -1.5 because 3 months is well below average,
and MonthlyCharges of $85 scales to +0.8):

`z = -0.2 + (-1.2 × -1.5) + (0.5 × 0.8) + (1.4 × 1) + (0.8 × 1) + (-0.3 × 0)`

`z = -0.2 + 1.8 + 0.4 + 1.4 + 0.8 + 0`

`z = 4.2`

---

**Step 3 — Apply the sigmoid function**
"""
    )
    st.latex(r"P(\text{churn}) = \frac{1}{1 + e^{-z}} = \frac{1}{1 + e^{-4.2}} = \frac{1}{1 + 0.015} = 0.985")
    st.markdown(
        """
**Probability = 98.5%** → Predict **Churn**

---

**Step 4 — How did the model learn these weights?**

| Phase | What happens |
|---|---|
| Start | All weights = 0 (or random small numbers) |
| Each step | Predict for the training customers, calculate the error |
| Update | Adjust weights using gradient descent |
| | `weight = weight - learning_rate × gradient` |
| Repeat | Until weights converge (error stops decreasing) |

The gradient tells the model: "This weight should go up" or "down" based on
whether increasing it would reduce the prediction error.
After training on all ~5,600 training customers for many iterations,
the weights settle to values that best separate churners from non-churners.

---

**Why this customer scores so high:**

| Feature | Contribution to z |
|---|---|
| Short tenure (-1.5 scaled) | +1.8 (weight is negative, value is negative → positive contribution) |
| Month-to-month contract | +1.4 |
| Fiber optic | +0.8 |
| High charges | +0.4 |
| Bias term | -0.2 |
| **Total (z)** | **+4.2** → sigmoid → 98.5% |

Every feature contributes additively. That's the key property (and limitation)
of logistic regression — it can't model interactions like "fiber optic is
only risky for short-tenure customers."
"""
    )
# ── Naive Bayes ──────────────────────────────────────────────────────────
# Flow diagram, Bayes' theorem, then a two-feature worked example.
st.markdown("---")
st.markdown("### 2. Naive Bayes (Gaussian)")
st.markdown("*A probabilistic model that assumes features are independent.*")
st.graphviz_chart("""
digraph nb {
    rankdir=LR
    node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.3,0.15"]
    edge [color="#888888", penwidth=1.5]
    features [label="Customer Features\\ntenure, charges,\\ncontract type, ...", fillcolor="#dbeafe", color="#3b82f6"]
    prob [label="Calculate Probability\\nfor each class\\nusing Bayes' Theorem", fillcolor="#e0e7ff", color="#6366f1"]
    compare [label="Compare\\nP(Churn | features)\\nvs\\nP(Retain | features)", fillcolor="#fce7f3", color="#ec4899"]
    output [label="Pick Higher\\nProbability", fillcolor="#d1fae5", color="#10b981"]
    features -> prob -> compare -> output
}
""")
st.markdown(
    """
**How it works:**

Naive Bayes uses **Bayes' Theorem** to calculate the probability of churn given the customer's features.
The "naive" part means it assumes all features are **independent** once the class (Churn or Retain)
is known — within each class, knowing tenure doesn't tell you anything about monthly charges
(which isn't true in reality, but the assumption simplifies the math).

**Formula:**
"""
)
st.latex(r"P(\text{Churn} \mid \text{features}) = \frac{P(\text{features} \mid \text{Churn}) \times P(\text{Churn})}{P(\text{features})}")
st.markdown(
    """
For continuous features (like tenure, charges), Gaussian Naive Bayes assumes each feature
follows a **normal distribution** within each class.

**Step-by-step example:**

Suppose we want to predict if Customer X will churn. They have:

- Tenure = 12 months
- Monthly Charges = $70

**Step 1:** Calculate P(Tenure=12 | Churn) and P(Tenure=12 | Retain)

From training data, we know:

- Churners have average tenure = 18 months, std = 15
- Retainers have average tenure = 38 months, std = 20

Using the Gaussian (bell curve) density formula:

- P(Tenure=12 | Churn) = 0.024 (12 is close to churner average)
- P(Tenure=12 | Retain) = 0.008 (12 is far from retainer average)

(Strictly, these are probability *densities* rather than probabilities, but
comparing them between the two classes works the same way.)

**Step 2:** Do the same for Monthly Charges

- P(Charges=70 | Churn) = 0.015
- P(Charges=70 | Retain) = 0.012

**Step 3:** Multiply probabilities (the "naive" independence assumption)

- P(features | Churn) = 0.024 × 0.015 = 0.00036
- P(features | Retain) = 0.008 × 0.012 = 0.000096

**Step 4:** Apply Bayes' Theorem

- P(Churn | features) ∝ 0.00036 × P(Churn) = 0.00036 × 0.27 = 0.0000972
- P(Retain | features) ∝ 0.000096 × P(Retain) = 0.000096 × 0.73 = 0.00007008

Normalize: P(Churn | features) = 0.0000972 / (0.0000972 + 0.00007008) = **58%**

**Prediction: Churn (58% probability)**

**Why Naive Bayes is different:**

- Makes strong independence assumption (features don't interact)
- Very fast to train and predict
- Works well when features are actually somewhat independent
- Often gives different predictions than tree-based models because it doesn't
  capture feature interactions (e.g., "fiber optic + short tenure" combo)
"""
)
# ── Random Forest ────────────────────────────────────────────────────────
# Ensemble diagram, single-tree diagram, prose, then a worked example.
st.markdown("---")
st.markdown("### 3. Random Forest")
st.markdown("*An ensemble of decision trees that vote together.*")
st.graphviz_chart("""
digraph rf {
    node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
    edge [color="#888888", penwidth=1.5]
    data [label="Training Data", fillcolor="#dbeafe", color="#3b82f6"]
    subgraph cluster_trees {
        label="200 Decision Trees (each sees a random subset)"
        style=dashed
        color="#94a3b8"
        fontname="Helvetica"
        fontsize=10
        t1 [label="Tree 1\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
        t2 [label="Tree 2\\n→ Retain", fillcolor="#d1fae5", color="#10b981"]
        t3 [label="Tree 3\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
        dots [label="...\\n(197 more)", shape=plaintext]
        t200 [label="Tree 200\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
    }
    vote [label="Majority Vote\\n3 out of 4 say Churn\\n→ Final: Churn (75%)", fillcolor="#fef3c7", color="#f59e0b"]
    data -> t1
    data -> t2
    data -> t3
    data -> t200
    t1 -> vote
    t2 -> vote
    t3 -> vote
    t200 -> vote
}
""")
st.markdown("**How a single decision tree works:**")
st.graphviz_chart("""
digraph tree_example {
    node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.2,0.1"]
    edge [fontname="Helvetica", fontsize=9, color="#888888", penwidth=1.5]
    q1 [label="Contract = \\nMonth-to-month?", fillcolor="#e0e7ff", color="#6366f1"]
    q2 [label="tenure < 12\\nmonths?", fillcolor="#e0e7ff", color="#6366f1"]
    q3 [label="MonthlyCharges\\n> $70?", fillcolor="#e0e7ff", color="#6366f1"]
    churn1 [label="CHURN\\n(85% probability)", fillcolor="#fee2e2", color="#ef4444"]
    retain1 [label="RETAIN\\n(70% probability)", fillcolor="#d1fae5", color="#10b981"]
    churn2 [label="CHURN\\n(60% probability)", fillcolor="#fee2e2", color="#ef4444"]
    retain2 [label="RETAIN\\n(90% probability)", fillcolor="#d1fae5", color="#10b981"]
    q1 -> q2 [label=" Yes"]
    q1 -> q3 [label=" No"]
    q2 -> churn1 [label=" Yes"]
    q2 -> retain1 [label=" No"]
    q3 -> churn2 [label=" Yes"]
    q3 -> retain2 [label=" No"]
}
""")
st.markdown(
    """
**How it works:**

1. Imagine asking a series of yes/no questions: "Is the contract month-to-month?"
   → "Is tenure less than 12 months?" → "Are charges above $70?"
   That's a **decision tree** — it keeps splitting until it reaches an answer.
2. A single tree is easy to overfit (it memorizes the training data too closely).
   So Random Forest builds **200 trees**, each trained on a **random sample**
   of the data and a **random subset** of features.
3. For a new customer, all 200 trees make their prediction and we take
   the **majority vote**.

**Why "Random"?** Each tree only sees a random portion of the data and features.
This diversity prevents the forest from over-relying on any single pattern.

**Why Label Encoding is fine:** Trees split on thresholds (e.g., "is Contract < 1?").
They never multiply the encoded number — so the integer values don't introduce false relationships.
(The codes still impose an order, so grouping non-adjacent categories can take an extra
split, but trees handle that easily.)

**Strengths:** Handles non-linear patterns, resistant to overfitting, works with Label Encoding.

**Weaknesses:** Slower than Logistic Regression, less interpretable (200 trees are hard to inspect by hand).
"""
)

with st.expander("Step-by-step numerical walkthrough — Random Forest", expanded=False):
    st.markdown(
        """
**Same customer:** tenure=3, MonthlyCharges=$85, Contract=Month-to-month,
Internet=Fiber optic, Partner=No.

---

**Step 1 — Each tree sees different data**

Random Forest creates 200 trees. Each tree gets:

- A **bootstrap sample** — as many customers as the training set, drawn at random
  *with replacement*, so each tree sees only ~63% of the distinct customers
  (some of them more than once)
- A **random subset of features**, redrawn at every split (e.g., 4 out of 19 features)

| Tree | Training Customers | Features considered at the first split |
|---|---|---|
| Tree 1 | Customers #12, #45, #45, #78, #102, ... | tenure, Contract, Partner, TechSupport |
| Tree 2 | Customers #3, #22, #56, #56, #89, ... | MonthlyCharges, InternetService, gender, SeniorCitizen |
| Tree 3 | Customers #7, #33, #41, #67, #67, ... | tenure, PaymentMethod, MonthlyCharges, OnlineSecurity |
| ... | ... | ... |

Notice: Customer #45 appears twice in Tree 1 (bootstrap sampling), and each
tree — in fact each split — considers a different set of features.

---

**Step 2 — Each tree grows by finding the best splits**

**Tree 1** (its first split can choose from tenure, Contract, Partner, TechSupport):

The algorithm tries every possible split and picks the one that best
separates churners from non-churners (measured by **Gini impurity**):
"""
    )
    st.latex(r"\text{Gini} = 1 - p_{\text{churn}}^2 - p_{\text{retain}}^2")
    st.markdown(
        """
The parent node has 26.5% churn, so its Gini is 1 − 0.265² − 0.735² ≈ 0.390.
A split's **Gini improvement** is that parent value minus the size-weighted Gini
of the two groups it creates. (Group sizes here are chosen so that each split
still averages to 26.5% churn.)

| Candidate Split | Left Group (Churn%) | Right Group (Churn%) | Gini Improvement |
|---|---|---|---|
| Contract < 1 (Month-to-month) | 42% churn | 12% churn | **0.045** ← best |
| tenure < 12 | 38% churn | 20% churn | 0.015 |
| Partner = 0 | 30% churn | 24% churn | 0.002 |

Contract split wins → becomes the first question.
Then each branch splits again (each split drawing a fresh random subset of features),
creating deeper questions. This continues until leaves are pure or a depth limit is reached.

---

**Step 3 — Each tree predicts independently**

For our customer (Month-to-month, tenure=3, Fiber optic, no Partner):

| Tree | Path through the tree | Prediction |
|---|---|---|
| Tree 1 | Contract=M-t-m → tenure<12 → **CHURN** | Churn (92%) |
| Tree 2 | Charges>$70 → Fiber optic → **CHURN** | Churn (78%) |
| Tree 3 | tenure<6 → Charges>$60 → **CHURN** | Churn (85%) |
| Tree 4 | Contract=M-t-m → No TechSupport → **CHURN** | Churn (88%) |
| Tree 5 | Fiber optic → tenure<24 → **RETAIN** | Retain (55%) |
| ... | ... | ... |

---

**Step 4 — Majority vote**

Out of 200 trees:

- 172 trees predict **Churn**
- 28 trees predict **Retain**

`Probability = 172/200 = 86%` → Predict **Churn**

(scikit-learn's `RandomForestClassifier` actually averages each tree's predicted
probabilities — "soft voting" — instead of counting hard votes, but the idea is the same.)

Note: Tree 5 predicted Retain — that's fine. The diversity is intentional.
If every tree agreed perfectly, there would be no benefit from having 200 of them.

---

**Why Random Forest handles non-linear patterns:**

A single tree can capture "Contract=Month-to-month AND tenure<12 → high churn"
without needing to encode this interaction explicitly.

The forest combines 200 different views of the data, each capturing
different interaction patterns. The majority vote smooths out individual
tree errors — this is why a forest overfits far less than a single deep tree.
"""
    )
# ── XGBoost ──────────────────────────────────────────────────────────────
# Boosting diagram, prose, then a worked example in log-odds space.
st.markdown("---")
st.markdown("### 4. XGBoost (Extreme Gradient Boosting)")
st.markdown("*A go-to algorithm for tabular data — behind many winning Kaggle solutions.*")
st.graphviz_chart("""
digraph xgb {
    rankdir=TB
    node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
    edge [color="#888888", penwidth=1.5]
    data [label="Training Data", fillcolor="#dbeafe", color="#3b82f6"]
    t1 [label="Tree 1\\nLearns the main pattern\\n(e.g., contract type)", fillcolor="#d1fae5", color="#10b981"]
    e1 [label="Errors from Tree 1\\n(customers it got wrong)", fillcolor="#fee2e2", color="#ef4444"]
    t2 [label="Tree 2\\nFocuses on Tree 1's mistakes", fillcolor="#d1fae5", color="#10b981"]
    e2 [label="Remaining errors", fillcolor="#fee2e2", color="#ef4444"]
    t3 [label="Tree 3\\nFocuses on remaining mistakes", fillcolor="#d1fae5", color="#10b981"]
    dots [label="... (200 trees total, each fixing\\nthe previous tree's errors)", shape=plaintext, fontname="Helvetica"]
    final [label="Final Prediction\\nSum of all 200 trees\\n(each weighted by learning rate)", fillcolor="#fef3c7", color="#f59e0b"]
    data -> t1 -> e1 -> t2 -> e2 -> t3 -> dots -> final
}
""")
st.markdown(
    """
**How it works:**

1. **Tree 1** tries to predict churn for all customers. It gets some right, some wrong.
2. **Tree 2** doesn't start from scratch — it specifically focuses on the customers
   that Tree 1 got **wrong**. It tries to correct those mistakes.
3. **Tree 3** focuses on the remaining mistakes from Tree 1 + Tree 2 combined.
4. This continues for 200 rounds. Each new tree is a specialist in fixing
   what all previous trees couldn't get right.
5. The final prediction is the **weighted sum** of all 200 trees (in log-odds,
   which a sigmoid then turns into a probability).

**The key difference from Random Forest:**

- Random Forest: 200 trees trained **independently** (in parallel), then vote.
- XGBoost: 200 trees trained **sequentially**, each one learning from the previous one's errors.

**Why "Gradient Boosting"?** "Gradient" refers to the mathematical technique used to
determine *how* each new tree should focus on errors. It's the same gradient descent
concept used in deep learning.

**Strengths:** Usually the most accurate model, handles complex non-linear patterns,
built-in regularization prevents overfitting.

**Weaknesses:** Slower to train, harder to interpret, more hyperparameters to tune.
"""
)

with st.expander("Step-by-step numerical walkthrough — XGBoost", expanded=False):
    st.markdown(
        """
**Same customer:** tenure=3, MonthlyCharges=$85, Contract=Month-to-month,
Internet=Fiber optic, Partner=No.

---

**Step 1 — Start with a base prediction**

Before any trees, XGBoost starts with the global average:

`base prediction = overall churn rate = 0.265` (26.5%)

Convert to log-odds:
"""
    )
    st.latex(r"\text{log-odds} = \ln\left(\frac{0.265}{1 - 0.265}\right) = \ln(0.361) = -1.02")
    st.markdown(
        """
Every customer starts at -1.02 (26.5% churn probability).

---

**Step 2 — Tree 1 learns from the errors**

For each customer, calculate the **residual** (how wrong the base prediction is):

| Customer | Actual | Current Prediction | Residual |
|---|---|---|---|
| #1 | Churned (1) | 0.265 | +0.735 (prediction too low) |
| #2 | Retained (0) | 0.265 | -0.265 (prediction too high) |
| #3 | Churned (1) | 0.265 | +0.735 |
| Our customer | Churned (1) | 0.265 | +0.735 |

Tree 1 tries to predict these residuals (not the original labels).
Tree 1 output for our customer: **+0.45** (it learned a partial correction).

---

**Step 3 — Update prediction with learning rate**

XGBoost scales every tree's output by a **learning rate** (η) to prevent overshooting.
This example uses η = 0.1 (XGBoost's own default is 0.3):

`new log-odds = -1.02 + (0.1 × 0.45) = -1.02 + 0.045 = -0.975`
"""
    )
    st.latex(r"P = \frac{1}{1 + e^{-(-0.975)}} = \frac{1}{1 + 2.65} = 0.274")
    st.markdown(
        """
Probability moved from 26.5% to 27.4%. A small step in the right direction.

---

**Step 4 — Tree 2 focuses on remaining errors**

New residuals (using updated predictions):

| Customer | Actual | Updated Prediction | New Residual |
|---|---|---|---|
| Our customer | 1 | 0.274 | +0.726 (still too low) |

Tree 2 learns these new residuals. Output for our customer: **+0.52**

`log-odds = -0.975 + (0.1 × 0.52) = -0.975 + 0.052 = -0.923`

`P = 28.4%`

---

**Step 5 — Continue for 200 trees**

| After Tree | Log-odds | Churn Probability |
|---|---|---|
| Base (no trees) | -1.02 | 26.5% |
| Tree 1 | -0.975 | 27.4% |
| Tree 2 | -0.923 | 28.4% |
| Tree 10 | -0.42 | 39.7% |
| Tree 50 | 0.85 | 70.1% |
| Tree 100 | 1.72 | 84.8% |
| Tree 200 | 2.95 | 95.0% |

Each tree adds a small correction. After 200 trees, the probability
climbed from 26.5% to 95.0%.

---

**The final prediction formula:**
"""
    )
    st.latex(r"\text{log-odds}(x) = \text{base} + \eta \cdot f_1(x) + \eta \cdot f_2(x) + \ldots + \eta \cdot f_{200}(x)")
    st.markdown(
        """
Where each f(x) is a small decision tree focused on the remaining errors. The
final log-odds then pass through the sigmoid to give the churn probability.

---

**Key differences from Random Forest:**

| | Random Forest | XGBoost |
|---|---|---|
| **Trees learn** | Independently (in parallel) | Sequentially (each from previous errors) |
| **Each tree predicts** | The original label (Churn/Retain) | The **residual error** from all previous trees |
| **Learning rate** | Not applicable | Controls step size (η=0.1 is typical) |
| **Why it helps** | Diversity from random sampling | Each tree is a specialist in remaining errors |
| **Risk** | Hard to overfit | Can overfit if too many trees or learning rate too high |

The learning rate is crucial: without it (η=1.0), early trees would over-correct
and the model would overfit. Small steps (η=0.1) force the model to build
gradually, which produces better generalization.
"""
    )
# ── SHAP ─────────────────────────────────────────────────────────────────
# Waterfall-style diagram of feature contributions plus a prose explanation.
st.markdown("---")
st.markdown("### 5. SHAP — How We Explain Predictions")
st.markdown("*Making the black box transparent.*")
st.graphviz_chart("""
digraph shap {
    rankdir=LR
    node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
    edge [color="#888888", penwidth=1.5]
    base [label="Base Rate\\n26.5% churn\\n(average customer)", fillcolor="#e0e7ff", color="#6366f1"]
    f1 [label="Contract =\\nMonth-to-month\\n+18%", fillcolor="#fee2e2", color="#ef4444"]
    f2 [label="tenure =\\n2 months\\n+12%", fillcolor="#fee2e2", color="#ef4444"]
    f3 [label="TechSupport =\\nNo\\n+5%", fillcolor="#fee2e2", color="#ef4444"]
    f4 [label="TotalCharges =\\n$150 (low)\\n+3%", fillcolor="#fee2e2", color="#ef4444"]
    f5 [label="Partner = Yes\\n-4%", fillcolor="#d1fae5", color="#10b981"]
    pred [label="Final Prediction\\n60.5% churn", fillcolor="#fef3c7", color="#f59e0b"]
    base -> f1 -> f2 -> f3 -> f4 -> f5 -> pred
}
""")
st.markdown(
    """
**How SHAP works:**

Every prediction starts from the **base value** — the model's average prediction,
which is roughly the overall churn rate in the data (~26.5%).
Then SHAP shows how each feature **pushes** that prediction up or down:

- Month-to-month contract → pushes probability **up** (toward churn)
- Low tenure → pushes probability **up** (new customer, high risk)
- Having a partner → pushes probability **down** (slightly protective)
- Each feature gets a + or - contribution, and they all add up to the final prediction.

**Why this matters for business:** If the model predicts an 80% churn probability,
SHAP tells you *why* — "It's mainly because they're on a month-to-month contract
and only been with us for 2 months." That's actionable: offer them a yearly contract
with a discount.

**The name "SHAP"** stands for SHapley Additive exPlanations, based on Shapley values
from game theory — a mathematically rigorous way to fairly distribute credit among features.
"""
)
# Expandable worked example of how Shapley values are computed and why they add
# up. All numbers are illustrative teaching values, not model output.
with st.expander("Step-by-step numerical walkthrough — SHAP values", expanded=False):
    # Blank lines matter here: a "---" directly under a text line would turn
    # that line into a markdown heading instead of drawing a rule.
    st.markdown(
        """
SHAP uses a concept from game theory: **Shapley values**. The idea is to
fairly distribute credit among "players" (features) for the "payout" (prediction).

**The question:** XGBoost predicts 95% churn for our customer. But *which features
caused this?* And by how much?

---

**Step 1 — Start from the base value**

The base value is the average model output across all training customers:

`base value = 26.5%` (overall churn rate)

---

**Step 2 — Measure each feature's marginal contribution**

For each feature, SHAP asks: "If I remove this feature, how much does the
prediction change?" But it's more sophisticated — it checks this across
**every possible combination** of other features.

**Example for "Contract = Month-to-month":**

| Features included (besides Contract) | Prediction with Contract | Prediction without | Contribution of Contract |
|---|---|---|---|
| None | 45% | 26.5% | +18.5% |
| tenure only | 52% | 35% | +17% |
| tenure + MonthlyCharges | 68% | 48% | +20% |
| tenure + MonthlyCharges + Internet | 82% | 60% | +22% |
| All features | 95% | 73% | +22% |

The rows above are an illustrative sample. The exact Shapley value is a *weighted*
average of Contract's marginal contribution over **every** subset of the other
features; for this walkthrough we simply average the sampled rows:

SHAP value for Contract ≈ (18.5 + 17 + 20 + 22 + 22) / 5 = **+19.9%**

---

**Step 3 — Do this for every feature**

| Feature | SHAP Value | Direction |
|---|---|---|
| Contract = Month-to-month | +19.9% | Pushes toward churn |
| tenure = 3 months | +15.2% | Pushes toward churn |
| InternetService = Fiber optic | +11.8% | Pushes toward churn |
| MonthlyCharges = $85 | +8.4% | Pushes toward churn |
| TechSupport = No | +6.1% | Pushes toward churn |
| OnlineSecurity = No | +4.2% | Pushes toward churn |
| Partner = No | +2.1% | Pushes toward churn |
| PaymentMethod = Electronic check | +1.3% | Pushes toward churn |
| Other features | -0.5% | Small protective effects |

---

**Step 4 — Everything adds up to the prediction**

`26.5% (base) + 19.9% + 15.2% + 11.8% + 8.4% + 6.1% + 4.2% + 2.1% + 1.3% - 0.5% = 95.0%`

This is guaranteed by the Shapley value properties — the contributions
always sum exactly to the difference between the base value and the prediction.

*Note:* by default, SHAP's tree explainer works in XGBoost's raw output units
(log-odds), not probabilities. Percentages are used here to keep the arithmetic
easy to follow.

---

**Why Shapley values are "fair":**

Shapley values are the *only* attribution method that satisfies all four fairness properties:

1. **Efficiency** — contributions sum to the prediction minus the base value
2. **Symmetry** — features with identical effects get identical values
3. **Null player** — irrelevant features get zero contribution
4. **Linearity** — the explanation of a sum of models is the sum of their explanations

This is why SHAP is preferred over simpler methods like "feature importance" —
it's mathematically guaranteed to distribute credit fairly.
"""
    )
# ── Comparison Table ─────────────────────────────────────────────────────
# Side-by-side summary of all four trained models. Naive Bayes is included
# because it is trained and evaluated alongside the others (on the same
# label-encoded features as the tree models).
st.markdown("---")
st.markdown("### Quick Comparison")
st.markdown(
    """
| | Logistic Regression | Naive Bayes | Random Forest | XGBoost |
|---|---|---|---|---|
| **How it learns** | Finds the best weights for a linear equation | Applies Bayes' theorem, assuming features are independent given the outcome | Builds many independent trees and averages their votes | Builds trees sequentially, each correcting the last |
| **Encoding** | One-Hot (needs binary columns) | Label (same features as the tree models) | Label (integers are fine) | Label (integers are fine) |
| **Speed** | Very fast | Very fast | Moderate | Slower |
| **Accuracy** | Good baseline | Baseline (limited by the independence assumption) | Very good | Usually best |
| **Interpretability** | High (weights = feature importance) | High (per-feature class-conditional statistics) | Medium (feature importance available) | Medium (needs SHAP for full explanation) |
| **Best for** | Simple, linear relationships | Fast baselines; checking whether feature interactions matter | Non-linear patterns with moderate data | Complex patterns, competitions, production systems |
"""
)
# ── Why Each Model ────────────────────────────────────────────────────────
# Narrative rationale for each model plus SHAP. The text is an f-string so the
# training-set size and the Logistic Regression weight count come from the data
# actually used (X_train / feature_cols_oh) instead of being hard-coded.
# The text must contain no literal braces, since it is an f-string.
st.markdown("---")
st.markdown("### Why We Chose Each Model")
st.markdown(
    f"""
#### Logistic Regression — The Baseline

| | |
|---|---|
| **Why we included it** | Every ML project needs a simple baseline to compare against. If a complex model can't beat Logistic Regression, the complexity isn't justified. |
| **Where it's used in industry** | Credit scoring (lenders must be able to explain their decisions), medical diagnosis screening, spam detection, ad click prediction at scale. |
| **Key benefit** | Full transparency — every feature gets a weight you can show to stakeholders. "A month-to-month contract multiplies the odds of churn by X." |
| **Limitation** | Assumes each feature contributes additively (on the log-odds scale). Can't learn "fiber optic is only risky when tenure is short" without manually creating that interaction feature. |
| **When to pick this** | When you need to explain *why* to regulators or non-technical stakeholders, or when you need a fast model that runs on millions of rows per second. |

---

#### Naive Bayes — The Probabilistic Baseline

| | |
|---|---|
| **Why we included it** | Provides a fundamentally different approach based on probability theory. Makes the independence assumption explicit, which helps identify when features actually interact. |
| **Where it's used in industry** | Email spam filtering (Gmail), document classification, real-time sentiment analysis, medical diagnosis screening. |
| **Key benefit** | Extremely fast training and prediction. Works well with small datasets. Outputs class probabilities directly — though they tend to be overconfident (pushed toward 0% or 100%), so calibrate them before treating them as risk scores. |
| **Limitation** | The independence assumption is often violated (e.g., contract type and monthly charges are correlated). Can't capture feature interactions. |
| **When to pick this** | When speed matters more than accuracy, or as a baseline to test if feature interactions are important (if NB matches complex models, interactions probably matter little). |

---

#### Random Forest — The Reliable Workhorse

| | |
|---|---|
| **Why we included it** | It captures non-linear patterns and feature interactions that Logistic Regression and Naive Bayes miss, without needing manual feature engineering. |
| **Where it's used in industry** | Fraud detection (PayPal), customer segmentation, manufacturing defect prediction, insurance risk assessment. |
| **Key benefit** | Robust to outliers and noise. Averaging many trees makes it much less prone to overfitting than a single decision tree. Needs no feature scaling, and label-encoded categories work fine. |
| **Limitation** | Slower inference than Logistic Regression or Naive Bayes. The 200-tree ensemble is harder to explain to non-technical audiences. |
| **When to pick this** | When you need reliable accuracy with minimal tuning. It's the "safe choice" — almost always performs well without surprises. |

---

#### XGBoost — The Performance Leader

| | |
|---|---|
| **Why we included it** | It's consistently the top-performing algorithm for structured/tabular data like our customer dataset. |
| **Where it's used in industry** | Used in 17 of the 29 Kaggle winning solutions published in 2015 (per the original XGBoost paper) and still a staple of tabular competitions. Used at Airbnb (search ranking), Uber (ETA prediction), major banks (credit risk). |
| **Key benefit** | Sequential error correction means it learns from its own mistakes. Built-in regularization prevents overfitting. |
| **Limitation** | More hyperparameters to tune. Slower to train than Random Forest. Requires SHAP for interpretability. |
| **When to pick this** | When predictive accuracy is the priority and you have time to tune. For production churn models, XGBoost is typically the default choice. |

---

#### SHAP — The Explainability Layer

| | |
|---|---|
| **Why we included it** | A prediction without an explanation has no business value. SHAP makes any model interpretable. |
| **Where it's used in industry** | Helps meet explainability obligations such as the GDPR's rules on automated decision-making ("meaningful information about the logic involved"). Widely used in banking for credit decisions and at tech companies for model debugging. |
| **Key benefit** | Works with *any* model. Gives both global importance (which features matter overall) and local explanations (why *this* customer was flagged). |
| **When to use** | Always. SHAP should be part of every production ML system. |

---

#### How They Complement Each Other

| Role | Model | Why |
|---|---|---|
| **Linear baseline** | Logistic Regression | Sets the floor — any model must beat this to justify its complexity |
| **Probabilistic baseline** | Naive Bayes | Tests if the feature independence assumption holds. If NB performs poorly, we know interactions matter. |
| **Production model** | XGBoost | Highest accuracy for deployment |
| **Backup / ensemble** | Random Forest | If XGBoost overfits on new data, RF is a stable alternative |
| **Explainability** | SHAP on XGBoost | Turns the best model into something stakeholders can act on |
| **Live updates** | SGDClassifier (Live page) | The only model supporting `partial_fit` for incremental learning |

In a real deployment, you would use XGBoost as the primary model with SHAP
for explanations, keep Random Forest as a monitoring baseline, compare against
Naive Bayes to validate that feature interactions are being captured, and use the
SGDClassifier pattern for real-time drift adaptation.

---

#### Why Logistic Regression Can Match or Beat Complex Models on This Dataset

If Logistic Regression performs as well as XGBoost here, that's an important
finding — not a flaw.

| Factor | Why It Helps LR |
|---|---|
| **Small dataset (~{len(X_train):,} training rows)** | XGBoost has far more parameters (200 trees × many splits) than LR (~{len(feature_cols_oh)} weights, one per one-hot column). With limited data, complex models risk overfitting. LR's simplicity is an advantage. |
| **Linear churn signal** | Contract type, tenure, and charges have straightforward relationships with churn. These are essentially linear effects — exactly what LR is designed for. |
| **One-hot encoding** | LR receives one-hot encoded features (each category as its own binary column). This avoids false ordinal relationships that label encoding introduces for linear models. |
| **Strong regularization** | LR's built-in L2 regularization prevents any single feature from dominating, keeping the model stable. |

**What this means in practice:** A simpler, fully interpretable model achieving
top accuracy is a **win**. Every prediction can be explained to the retention team,
the model is faster to deploy, and regulatory compliance is easier.

**When would XGBoost pull ahead?** With a much larger dataset (100K+ customers)
and richer features (browsing behavior, support call logs, time-series usage patterns),
XGBoost would exploit non-linear interactions that LR can't capture — like
"fiber optic is only risky for short-tenure customers without tech support."
"""
)
# ── Tab 0: Predict and Compare ───────────────────────────────────────────────
# Runs every trained model on a small, fixed set of held-out test customers and
# shows each model's prediction next to the true outcome.
with tab_predict:
    raw_df_pred = load_raw_data()
    # Demo set: the first 3 churned and first 2 retained customers of the test split.
    churned_idx = y_test[y_test == 1].index[:3]
    retained_idx = y_test[y_test == 0].index[:2]
    demo_idx = churned_idx.tolist() + retained_idx.tolist()

    st.subheader("Predict and Compare — Live Demo")
    st.markdown(
        f"Below are **{len(demo_idx)} real customers** from the test set (data the models have never seen during training). "
        "Click **Run Predictions** to see what each model thinks — then compare against what actually happened."
    )

    display_columns = [
        "customerID", "gender", "SeniorCitizen", "tenure", "Contract",
        "InternetService", "MonthlyCharges", "TotalCharges",
    ]
    demo_display = raw_df_pred.loc[demo_idx, display_columns].copy()
    demo_display.index = range(1, len(demo_display) + 1)
    demo_display.index.name = "#"
    st.markdown("#### Customer Profiles")
    st.dataframe(demo_display, use_container_width=True)

    if st.button("Run Predictions", type="primary", use_container_width=True):
        st.markdown("---")
        st.markdown("#### Prediction Results")
        model_names = list(all_models.keys())

        # One layout shared by the header and every result row:
        # "#", customer ID, one column per model, actual outcome, overall verdict.
        col_widths = [0.5, 1.5] + [2] * len(model_names) + [1.2, 0.5]
        actual_col = len(model_names) + 2
        verdict_col = len(model_names) + 3
        status_icons = {"all_correct": "✅", "all_wrong": "❌", "mixed": "⚠️"}

        header_cols = st.columns(col_widths)
        header_cols[0].markdown("**#**")
        header_cols[1].markdown("**Customer ID**")
        for j, mn in enumerate(model_names):
            header_cols[j + 2].markdown(f"**{mn}**")
        header_cols[actual_col].markdown("**Actual**")
        header_cols[verdict_col].markdown("**✓**")
        st.markdown("---")

        # Score every demo customer with every model first, then render.
        results_rows = []
        for pos, idx in enumerate(demo_idx, 1):
            actual_val = y_test.loc[idx]
            cells = {}  # model name -> (display text, prediction was correct)
            for model_name in model_names:
                model_obj = all_models[model_name]
                # Each model is scored on the encoding it was trained on.
                row_enc = model_test_data[model_name].loc[[idx]]
                pred = int(model_obj.predict(row_enc)[0])
                proba = float(model_obj.predict_proba(row_enc)[0][1])
                pred_label = "Churned" if pred == 1 else "Retained"
                cells[model_name] = (f"{pred_label} ({proba:.0%})", bool(pred == actual_val))

            hits = [correct for _, correct in cells.values()]
            if all(hits):
                status = "all_correct"
            elif not any(hits):
                status = "all_wrong"
            else:
                status = "mixed"

            results_rows.append({
                "#": pos,
                "Customer ID": raw_df_pred.loc[idx, "customerID"],
                "Actual": "Churned" if actual_val == 1 else "Retained",
                "cells": cells,
                "_status": status,
            })

        for row in results_rows:
            cols = st.columns(col_widths)
            cols[0].markdown(f"**{row['#']}**")
            cols[1].markdown(f"`{row['Customer ID']}`")
            for j, mn in enumerate(model_names):
                text, correct = row["cells"][mn]
                dot, mark = ("🟢", "✓") if correct else ("🔴", "✗")
                cols[j + 2].markdown(f"{dot} {text} {mark}")
            cols[actual_col].markdown(f"**{row['Actual']}**")
            cols[verdict_col].markdown(status_icons[row["_status"]])
            st.markdown("---")

        n_total = len(results_rows)
        n_all_correct = sum(r["_status"] == "all_correct" for r in results_rows)
        n_all_wrong = sum(r["_status"] == "all_wrong" for r in results_rows)
        n_mixed = n_total - n_all_correct - n_all_wrong
        rc1, rc2, rc3 = st.columns(3)
        # Deltas act as badges; show each only when its count is non-zero.
        rc1.metric("All Models Correct", f"{n_all_correct} / {n_total}", delta="✅" if n_all_correct > 0 else None)
        rc2.metric("Mixed Results", f"{n_mixed} / {n_total}", delta="⚠️" if n_mixed > 0 else None)
        rc3.metric("All Models Wrong", f"{n_all_wrong} / {n_total}", delta="❌" if n_all_wrong > 0 else None, delta_color="inverse")
        st.markdown(
            "**Legend:** 🟢 = correct prediction, 🔴 = wrong prediction. "
            "Percentage shown is the model's estimated churn probability."
        )
# ── Tab 1: Model Comparison ──────────────────────────────────────────────────
# Metrics table (with each model's feature encoding), ROC curves and one
# confusion matrix per trained model.
with tab_compare:
    st.subheader("Performance Metrics")
    metrics_df = pd.DataFrame(metrics).T
    metrics_df = metrics_df.round(3)
    best_model = metrics_df["AUC"].idxmax()
    st.info(f"Best model by AUC: **{best_model}** ({metrics_df.loc[best_model, 'AUC']:.3f})")

    # Encoding used at training time: LR gets the one-hot matrix, every other
    # model (including Naive Bayes) the label-encoded one.
    model_encodings = {
        "Logistic Regression": "One-Hot",
        "Naive Bayes": "Label",
        "Random Forest": "Label",
        "XGBoost": "Label",
    }
    # Reindex so rows follow metrics_df's order and align one-to-one with it.
    encoding_col = pd.Series(model_encodings, name="Encoding").reindex(metrics_df.index)
    display_metrics = pd.concat([encoding_col, metrics_df], axis=1)
    st.dataframe(
        display_metrics.style.highlight_max(axis=0, subset=metrics_df.columns, color="#c6efce"),
        use_container_width=True,
    )

    st.markdown("---")
    st.subheader("ROC Curves")
    roc_entries = [(name, model, model_test_data[name]) for name, model in all_models.items()]
    st.plotly_chart(plot_roc_curves(roc_entries, y_test), use_container_width=True)

    st.markdown("---")
    st.subheader("Confusion Matrices")
    # One column per model so the layout follows the number of trained models.
    cm_cols = st.columns(len(all_models))
    for cm_col, (name, model) in zip(cm_cols, all_models.items()):
        with cm_col:
            y_pred = model.predict(model_test_data[name])
            st.plotly_chart(
                plot_confusion_matrix(y_test, y_pred, title=name),
                use_container_width=True,
            )
# ── Tab 2: Global SHAP ──────────────────────────────────────────────────────
# Global SHAP views for the XGBoost model. `xgb_model` and `explainer` are
# defined here and reused by the later tabs, so their names must not change.
with tab_shap_global:
    st.subheader("Global Feature Importance — XGBoost")
    st.markdown("SHAP values show how much each feature pushes the prediction toward or away from churn.")
    xgb_model = all_models["XGBoost"]
    explainer, shap_values = get_shap_explainer(xgb_model, X_train)

    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    shap.plots.bar(shap_values, max_display=15, show=False, ax=ax_bar)
    st.pyplot(fig_bar)
    # Streamlit re-executes this script on every interaction; pyplot figures
    # that are never closed accumulate in memory across reruns.
    plt.close(fig_bar)

    st.markdown("---")
    st.markdown("**Beeswarm Plot** — Each dot is a customer. Color = feature value (red = high, blue = low).")
    plt.subplots(figsize=(10, 8))
    # The beeswarm plot is drawn onto the current pyplot figure.
    shap.plots.beeswarm(shap_values, max_display=15, show=False)
    fig_bee = plt.gcf()
    st.pyplot(fig_bee)
    plt.close(fig_bee)
# ── Tab 3: Individual Explanations ───────────────────────────────────────────
# Explains the XGBoost prediction for one selected customer with a SHAP
# waterfall. Relies on `xgb_model` and `explainer` from the Global SHAP tab.
with tab_shap_individual:
    st.subheader("Explain a Single Customer's Prediction")
    raw_df = load_raw_data()
    customer_ids = raw_df["customerID"].values
    selected_id = st.selectbox("Select Customer ID", customer_ids[:200])
    customer_raw = raw_df[raw_df["customerID"] == selected_id]
    idx_in_raw = customer_raw.index[0]

    # Prefer the already-encoded split holding this customer; only fall back to
    # the full encoded dataset for rows in neither split.
    if idx_in_raw in X_test.index:
        row = X_test.loc[[idx_in_raw]]
        actual = y_test.loc[idx_in_raw]
    elif idx_in_raw in X_train.index:
        row = X_train.loc[[idx_in_raw]]
        actual = y_train.loc[idx_in_raw]
    else:
        enc_df, _ = get_encoded_data()
        row = enc_df.loc[[idx_in_raw], feature_cols]
        actual = enc_df.loc[idx_in_raw, "Churn"]

    proba = xgb_model.predict_proba(row)[0][1]
    c1, c2 = st.columns(2)
    c1.metric("Predicted Churn Probability", f"{proba:.1%}")
    c2.metric("Actual Outcome", "Churned" if actual == 1 else "Retained")
    st.markdown("**Customer Details (raw values):**")
    st.dataframe(customer_raw.T, use_container_width=True)

    st.markdown("---")
    st.markdown("**SHAP Waterfall — Why this prediction?**")
    sv = get_shap_single(explainer, row)
    plt.subplots(figsize=(10, 6))
    # The waterfall plot is drawn onto the current pyplot figure.
    shap.plots.waterfall(sv[0], max_display=12, show=False)
    fig_wf = plt.gcf()
    st.pyplot(fig_wf)
    # Close so figures don't pile up across Streamlit reruns.
    plt.close(fig_wf)
# ── Tab 4: What-If Predictor ─────────────────────────────────────────────────
# Lets the user build a hypothetical customer, encodes it the same way as the
# training data, and shows XGBoost's churn probability plus a SHAP waterfall.
# Relies on `xgb_model` and `explainer` from the Global SHAP tab.
with tab_whatif:
    st.subheader("What-If Predictor")
    st.markdown("Adjust customer features and see how churn probability changes in real time.")

    wi_col1, wi_col2, wi_col3 = st.columns(3)
    with wi_col1:
        wi_gender = st.selectbox("Gender", ["Female", "Male"], key="wi_gender")
        wi_senior = st.selectbox("Senior Citizen", [0, 1], key="wi_senior")
        wi_partner = st.selectbox("Partner", ["Yes", "No"], key="wi_partner")
        wi_dependents = st.selectbox("Dependents", ["Yes", "No"], key="wi_dep")
        wi_tenure = st.slider("Tenure (months)", 0, 72, 12, key="wi_tenure")
    with wi_col2:
        wi_phone = st.selectbox("Phone Service", ["Yes", "No"], key="wi_phone")
        wi_multi = st.selectbox("Multiple Lines", ["No", "Yes", "No phone service"], key="wi_multi")
        wi_internet = st.selectbox("Internet Service", ["DSL", "Fiber optic", "No"], key="wi_inet")
        wi_security = st.selectbox("Online Security", ["Yes", "No", "No internet service"], key="wi_sec")
        wi_backup = st.selectbox("Online Backup", ["Yes", "No", "No internet service"], key="wi_bak")
    with wi_col3:
        wi_protection = st.selectbox("Device Protection", ["Yes", "No", "No internet service"], key="wi_prot")
        wi_support = st.selectbox("Tech Support", ["Yes", "No", "No internet service"], key="wi_sup")
        wi_tv = st.selectbox("Streaming TV", ["Yes", "No", "No internet service"], key="wi_tv")
        wi_movies = st.selectbox("Streaming Movies", ["Yes", "No", "No internet service"], key="wi_mov")
        wi_contract = st.selectbox("Contract", ["Month-to-month", "One year", "Two year"], key="wi_con")

    wi_col4, wi_col5 = st.columns(2)
    with wi_col4:
        wi_paperless = st.selectbox("Paperless Billing", ["Yes", "No"], key="wi_paper")
        wi_payment = st.selectbox(
            "Payment Method",
            ["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"],
            key="wi_pay",
        )
    with wi_col5:
        wi_monthly = st.slider("Monthly Charges ($)", 18.0, 120.0, 70.0, step=0.5, key="wi_monthly")
        wi_total = st.slider("Total Charges ($)", 18.0, 9000.0, 1500.0, step=10.0, key="wi_total")

    # Categorical inputs are label-encoded with the training encoders;
    # numeric inputs are passed through unchanged.
    input_dict = {
        "gender": wi_gender, "Partner": wi_partner, "Dependents": wi_dependents,
        "PhoneService": wi_phone, "MultipleLines": wi_multi, "InternetService": wi_internet,
        "OnlineSecurity": wi_security, "OnlineBackup": wi_backup,
        "DeviceProtection": wi_protection, "TechSupport": wi_support,
        "StreamingTV": wi_tv, "StreamingMovies": wi_movies, "Contract": wi_contract,
        "PaperlessBilling": wi_paperless, "PaymentMethod": wi_payment,
    }
    numeric_dict = {
        "SeniorCitizen": wi_senior, "tenure": wi_tenure,
        "MonthlyCharges": wi_monthly, "TotalCharges": wi_total,
    }
    _, enc_map = get_encoded_data()
    encoded_input = {}
    unseen_values = []
    for col, val in input_dict.items():
        le = enc_map[col]
        if val in le.classes_:
            encoded_input[col] = le.transform([val])[0]
        else:
            # Option absent from the training data: keep the code-0 fallback,
            # but surface it instead of substituting silently.
            encoded_input[col] = 0
            unseen_values.append(f"{col}={val!r}")
    if unseen_values:
        st.warning(
            "These options were not seen in the training data and were encoded as 0, "
            "so the prediction may be unreliable: " + ", ".join(unseen_values)
        )
    encoded_input.update(numeric_dict)
    input_row = pd.DataFrame([encoded_input])[feature_cols]
    wi_proba = xgb_model.predict_proba(input_row)[0][1]

    st.markdown("---")
    res_col1, res_col2 = st.columns([1, 2])
    with res_col1:
        st.metric("Churn Probability", f"{wi_proba:.1%}")
        if wi_proba < 0.3:
            st.success("Low risk — customer likely to stay")
        elif wi_proba < 0.6:
            st.warning("Medium risk — consider retention offer")
        else:
            st.error("High risk — immediate intervention recommended")
    with res_col2:
        st.plotly_chart(plot_gauge(wi_proba), use_container_width=True)

    st.markdown("---")
    st.markdown("**Top Feature Drivers for This Configuration:**")
    sv_wi = get_shap_single(explainer, input_row)
    plt.subplots(figsize=(10, 5))
    # The waterfall plot is drawn onto the current pyplot figure.
    shap.plots.waterfall(sv_wi[0], max_display=10, show=False)
    fig_wi = plt.gcf()
    st.pyplot(fig_wi)
    # Every slider move reruns the script; close the figure so they don't accumulate.
    plt.close(fig_wi)