Spaces:
Sleeping
Sleeping
File size: 56,802 Bytes
f3a6f24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 
1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 | import streamlit as st
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
from utils.data_loader import (
load_raw_data, get_train_test, get_onehot_train_test,
get_encoded_data, CATEGORICAL_COLS, NUMERIC_COLS,
)
from utils.models import (
train_tree_models, train_lr_model, train_nb_model, evaluate_model,
get_shap_explainer, get_shap_single,
)
from utils.visualizations import plot_roc_curves, plot_confusion_matrix, plot_gauge
# Page chrome: must be the first Streamlit call on the page.
st.set_page_config(page_title="Churn Models", page_icon="🤖", layout="wide")
st.title("Churn Prediction Models")
st.markdown("---")

# Two views of the same train/test split:
#   * label-encoded features -> Naive Bayes, Random Forest, XGBoost
#   * one-hot encoded features -> Logistic Regression
# The one-hot loader's target vectors are discarded; the code assumes they are
# identical to y_train / y_test from the label-encoded split (TODO confirm in
# utils.data_loader).
X_train, X_test, y_train, y_test, encoders, feature_cols = get_train_test()
X_train_oh, X_test_oh, _, _, feature_cols_oh = get_onehot_train_test()
# Initialize session state for models. Streamlit re-executes this script on
# every interaction, so trained models and their metrics are kept in
# st.session_state to avoid retraining on each rerun.
if "churn_models_trained" not in st.session_state:
    st.session_state.churn_models_trained = False
    st.session_state.all_models = None
    st.session_state.model_test_data = None
    st.session_state.metrics = None

# Explain why the models see differently-shaped feature matrices. Naive Bayes is
# trained on the label-encoded matrix as well (see the training block), so it
# belongs in the label-encoding list.
st.info(
    "**Encoding note:** Logistic Regression is trained on **One-Hot Encoded** features "
    f"({len(feature_cols_oh)} columns) while Naive Bayes, Random Forest and XGBoost use "
    f"**Label Encoding** ({len(feature_cols)} columns). "
    "Each model gets the encoding that's optimal for it."
)
# Until the user trains the models, show only the training prompt and stop the
# script so the tabs below never run against empty session state.
if not st.session_state.churn_models_trained:
    st.warning("⚠️ Models not trained yet. Click the button below to train all models.")
    if st.button("🚀 Train All Models", type="primary", use_container_width=True):
        with st.spinner("Training models..."):
            tree_models = train_tree_models(X_train, y_train)
            lr_model = train_lr_model(X_train_oh, y_train)
            nb_model = train_nb_model(X_train, y_train)

            all_models = {
                "Logistic Regression": lr_model,
                "Naive Bayes": nb_model,
            }
            all_models.update(tree_models)

            # Every model except Logistic Regression was fit on the
            # label-encoded matrix. Derive the mapping from the trained model
            # names so it always matches whatever train_tree_models returns
            # (hard-coded keys would raise KeyError on a rename).
            model_test_data = {name: X_test for name in all_models}
            model_test_data["Logistic Regression"] = X_test_oh

            metrics = {
                name: evaluate_model(model, model_test_data[name], y_test)
                for name, model in all_models.items()
            }

        st.session_state.all_models = all_models
        st.session_state.model_test_data = model_test_data
        st.session_state.metrics = metrics
        st.session_state.churn_models_trained = True
        # Rerun so the page redraws in its "trained" state.
        st.rerun()
    st.stop()
# Models are trained - retrieve the models, their per-model test matrices and
# the pre-computed evaluation metrics from session state.
all_models = st.session_state.all_models
model_test_data = st.session_state.model_test_data
metrics = st.session_state.metrics

st.success("✅ All models trained and ready!")

tab_how, tab_predict, tab_compare, tab_shap_global, tab_shap_individual, tab_whatif = st.tabs(
    [
        "How The Algorithms Work",
        "Predict and Compare",
        "Model Comparison",
        "Feature Importance (SHAP)",
        "Individual Explanations",
        "What-If Predictor",
    ]
)
# ── Tab: How The Algorithms Work ─────────────────────────────────────────────
# Educational walkthrough of the four models plus SHAP. Content is rendered
# inside the "How The Algorithms Work" tab.
with tab_how:
    st.subheader("How Each Algorithm Works — A Visual Guide")
    st.markdown(
        "We use four very different algorithms to predict churn. Each approaches "
        "the problem in its own way. Understanding the differences helps explain "
        "why one model might outperform another and when predictions diverge."
    )
# ── Logistic Regression ──────────────────────────────────────────────────────
# Flow diagram, an interactive sigmoid plot and a prose explanation. Re-entering
# the tab container appends this content to the tab in call order.
with tab_how:
    st.markdown("---")
    st.markdown("### 1. Logistic Regression")
    st.markdown("*The simplest model — and often a strong baseline.*")
    st.graphviz_chart("""
        digraph lr {
            rankdir=LR
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.3,0.15"]
            edge [color="#888888", penwidth=1.5]
            features [label="Customer Features\\ntenure, charges,\\ncontract type, ...", fillcolor="#dbeafe", color="#3b82f6"]
            weighted [label="Weighted Sum\\nw₁×tenure + w₂×charges\\n+ w₃×contract + ...", fillcolor="#e0e7ff", color="#6366f1"]
            sigmoid [label="Sigmoid Function\\nSquash to 0-1", fillcolor="#fce7f3", color="#ec4899"]
            output [label="Probability\\n0.82 → Churn\\n0.15 → Retain", fillcolor="#d1fae5", color="#10b981"]
            features -> weighted -> sigmoid -> output
        }
    """)

    # plotly is imported locally for the sigmoid illustration.
    import plotly.graph_objects as go

    # Logistic function 1 / (1 + e^-z) sampled on z in [-8, 8].
    x_sig = np.linspace(-8, 8, 200)
    y_sig = 1 / (1 + np.exp(-x_sig))

    fig_sig = go.Figure()
    fig_sig.add_trace(
        go.Scatter(
            x=x_sig,
            y=y_sig,
            mode="lines",
            line=dict(color="#6366f1", width=3),
            name="Sigmoid",
        )
    )
    fig_sig.add_hline(
        y=0.5, line_dash="dash", line_color="gray", annotation_text="Decision boundary (0.5)"
    )
    # Shade the z < 0 (p < 0.5) region as "Retain" and z > 0 as "Churn".
    fig_sig.add_vrect(
        x0=-8, x1=0, fillcolor="#d1fae5", opacity=0.15,
        annotation_text="Retain", annotation_position="bottom left",
    )
    fig_sig.add_vrect(
        x0=0, x1=8, fillcolor="#fee2e2", opacity=0.15,
        annotation_text="Churn", annotation_position="bottom right",
    )
    fig_sig.update_layout(
        title="The Sigmoid Curve — Turns Any Number into a Probability",
        xaxis_title="Weighted Sum of Features (z = w₁x₁ + w₂x₂ + ...)",
        yaxis_title="Churn Probability",
        height=350,
        yaxis=dict(range=[0, 1]),
    )
    st.plotly_chart(fig_sig, use_container_width=True)

    # Blank lines matter in markdown: they separate paragraphs, lists and tables.
    st.markdown(
        """
        **How it works:**

        1. Each customer feature gets a **weight** (a number the model learns). For example,
           "month-to-month contract" might get a high positive weight (pushes toward churn),
           while "long tenure" gets a negative weight (pushes away from churn).
        2. The model adds up: `weight₁ × feature₁ + weight₂ × feature₂ + ...`
        3. This sum could be any number (-∞ to +∞). The **sigmoid function** (the S-curve above)
           squashes it into a probability between 0 and 1.
        4. If the probability > 0.5 → predict "Churn". Otherwise → "Retain".

        **Why we use One-Hot Encoding for this model:** Because the weighted sum treats numbers
        at face value. If we encoded "Month-to-month"=2 and "Two year"=0, the model would
        think month-to-month is "more" of something than two-year — which is nonsensical.
        One-Hot avoids this by giving each category its own yes/no column.

        **Strengths:** Fast, interpretable (weights directly tell you what matters), good baseline.

        **Weaknesses:** Assumes a linear relationship between features and the log-odds of churn.
        Can't capture complex interactions (e.g., "fiber optic is only risky for short-tenure customers").
        """
    )
# Worked numerical example for Logistic Regression, collapsed by default and
# placed inside the "How The Algorithms Work" tab.
with tab_how.expander("Step-by-step numerical walkthrough — Logistic Regression", expanded=False):
    st.markdown(
        """
        Let's trace through exactly how the model makes one prediction.

        **A single customer:**

        | Feature | Value |
        |---|---|
        | tenure | 3 months |
        | MonthlyCharges | $85 |
        | Contract_Month-to-month | 1 (one-hot) |
        | Contract_One year | 0 |
        | Contract_Two year | 0 |
        | InternetService_Fiber optic | 1 |
        | Partner | 0 (no) |

        (After one-hot encoding and scaling)

        ---

        **Step 1 — The model has learned weights (one per feature)**

        | Feature | Learned Weight | Meaning |
        |---|---|---|
        | tenure (scaled) | -1.2 | Longer tenure → less churn |
        | MonthlyCharges (scaled) | +0.5 | Higher charges → more churn |
        | Contract_Month-to-month | +1.4 | Month-to-month → high churn risk |
        | Contract_Two year | -0.9 | Two-year contract → protective |
        | InternetService_Fiber optic | +0.8 | Fiber optic → correlated with churn |
        | Partner | -0.3 | Having a partner → slightly protective |
        | (bias term) | -0.2 | Baseline offset |

        ---

        **Step 2 — Compute the weighted sum (z)**
        """
    )
    st.latex(r"z = b + w_1 x_1 + w_2 x_2 + w_3 x_3 + \ldots")
    st.markdown(
        """
        For our customer (tenure scaled to -1.5 since 3 months is below average, and
        MonthlyCharges scaled to +0.8 since the charge is above average):

        `z = -0.2 + (-1.2 × -1.5) + (0.5 × 0.8) + (1.4 × 1) + (0.8 × 1) + (-0.3 × 0)`

        `z = -0.2 + 1.8 + 0.4 + 1.4 + 0.8 + 0`

        `z = 4.2`

        ---

        **Step 3 — Apply the sigmoid function**
        """
    )
    st.latex(r"P(\text{churn}) = \frac{1}{1 + e^{-z}} = \frac{1}{1 + e^{-4.2}} = \frac{1}{1 + 0.015} = 0.985")
    st.markdown(
        """
        **Probability = 98.5%** → Predict **Churn**

        ---

        **Step 4 — How did the model learn these weights?**

        | Stage | What happens |
        |---|---|
        | Start | All weights = 0 (or random small numbers) |
        | Each step | Pick a customer, predict, calculate error |
        | Update | Adjust weights using gradient descent |
        | | `weight = weight - learning_rate × gradient` |
        | Repeat | Until weights converge (error stops decreasing) |

        The gradient tells the model: "This weight should go up" or "down" based on
        whether increasing it would reduce the prediction error.
        After training on all ~5,600 training customers for many iterations,
        the weights settle to values that best separate churners from non-churners.

        ---

        **Why this customer scores so high:**

        | Feature | Contribution to z |
        |---|---|
        | Short tenure (-1.5 scaled) | +1.8 (weight is negative, value is negative → positive contribution) |
        | Month-to-month contract | +1.4 |
        | Fiber optic | +0.8 |
        | High charges | +0.4 |
        | **Total push toward churn** | **+4.2** → sigmoid → 98.5% |

        Every feature contributes additively. That's the key property (and limitation)
        of logistic regression — it can't model interactions like "fiber optic is
        only risky for short-tenure customers."
        """
    )
# ── Naive Bayes ──────────────────────────────────────────────────────────────
# Flow diagram, Bayes' theorem and a worked Gaussian NB example, rendered in the
# "How The Algorithms Work" tab.
with tab_how:
    st.markdown("---")
    st.markdown("### 2. Naive Bayes (Gaussian)")
    st.markdown("*A probabilistic model that assumes features are independent given the class.*")
    st.graphviz_chart("""
        digraph nb {
            rankdir=LR
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.3,0.15"]
            edge [color="#888888", penwidth=1.5]
            features [label="Customer Features\\ntenure, charges,\\ncontract type, ...", fillcolor="#dbeafe", color="#3b82f6"]
            prob [label="Calculate Probability\\nfor each class\\nusing Bayes' Theorem", fillcolor="#e0e7ff", color="#6366f1"]
            compare [label="Compare\\nP(Churn | features)\\nvs\\nP(Retain | features)", fillcolor="#fce7f3", color="#ec4899"]
            output [label="Pick Higher\\nProbability", fillcolor="#d1fae5", color="#10b981"]
            features -> prob -> compare -> output
        }
    """)
    st.markdown(
        """
        **How it works:**

        Naive Bayes uses **Bayes' Theorem** to calculate the probability of churn given the customer's features.
        The "naive" part means it assumes all features are **independent within each class** — once you know
        whether a customer churned, knowing their tenure tells you nothing more about their monthly charges
        (which isn't true in reality, but the assumption simplifies the math).

        **Formula:**
        """
    )
    st.latex(r"P(\text{Churn} \mid \text{features}) = \frac{P(\text{features} \mid \text{Churn}) \times P(\text{Churn})}{P(\text{features})}")
    st.markdown(
        """
        For continuous features (like tenure, charges), Gaussian Naive Bayes assumes each feature
        follows a **normal distribution** within each class.

        **Step-by-step example:**

        Suppose we want to predict if Customer X will churn. They have:

        - Tenure = 12 months
        - Monthly Charges = $70

        **Step 1:** Calculate P(Tenure=12 | Churn) and P(Tenure=12 | Retain)

        From training data, we know:

        - Churners have average tenure = 18 months, std = 15
        - Retainers have average tenure = 38 months, std = 20

        Using the Gaussian (bell curve) formula (strictly these are density values rather
        than probabilities, but they are compared in exactly the same way):

        - P(Tenure=12 | Churn) = 0.024 (12 is close to churner average)
        - P(Tenure=12 | Retain) = 0.008 (12 is far from retainer average)

        **Step 2:** Do the same for Monthly Charges

        - P(Charges=70 | Churn) = 0.015
        - P(Charges=70 | Retain) = 0.012

        **Step 3:** Multiply probabilities (the "naive" independence assumption)

        - P(features | Churn) = 0.024 × 0.015 = 0.00036
        - P(features | Retain) = 0.008 × 0.012 = 0.000096

        **Step 4:** Apply Bayes' Theorem. The priors P(Churn) = 0.27 and P(Retain) = 0.73 are the
        class shares in the training data; the shared denominator P(features) cancels when we normalize.

        - P(Churn | features) ∝ 0.00036 × P(Churn) = 0.00036 × 0.27 = 0.0000972
        - P(Retain | features) ∝ 0.000096 × P(Retain) = 0.000096 × 0.73 = 0.00007008

        Normalize: P(Churn | features) = 0.0000972 / (0.0000972 + 0.00007008) = **58%**

        **Prediction: Churn (58% probability)**

        **Why Naive Bayes is different:**

        - Makes strong independence assumption (features don't interact)
        - Very fast to train and predict
        - Works well when features are actually somewhat independent
        - Often gives different predictions than tree-based models because it doesn't
          capture feature interactions (e.g., "fiber optic + short tenure" combo)
        """
    )
# ── Random Forest ────────────────────────────────────────────────────────────
# Ensemble diagram, an example single decision tree and a prose explanation,
# rendered in the "How The Algorithms Work" tab.
with tab_how:
    st.markdown("---")
    st.markdown("### 3. Random Forest")
    st.markdown("*An ensemble of decision trees that vote together.*")
    st.graphviz_chart("""
        digraph rf {
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
            edge [color="#888888", penwidth=1.5]
            data [label="Training Data", fillcolor="#dbeafe", color="#3b82f6"]
            subgraph cluster_trees {
                label="200 Decision Trees (each sees a random subset)"
                style=dashed
                color="#94a3b8"
                fontname="Helvetica"
                fontsize=10
                t1 [label="Tree 1\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
                t2 [label="Tree 2\\n→ Retain", fillcolor="#d1fae5", color="#10b981"]
                t3 [label="Tree 3\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
                dots [label="...\\n(197 more)", shape=plaintext]
                t200 [label="Tree 200\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
            }
            vote [label="Majority Vote\\n3 out of 4 say Churn\\n→ Final: Churn (75%)", fillcolor="#fef3c7", color="#f59e0b"]
            data -> t1
            data -> t2
            data -> t3
            data -> t200
            t1 -> vote
            t2 -> vote
            t3 -> vote
            t200 -> vote
        }
    """)

    st.markdown("**How a single decision tree works:**")
    st.graphviz_chart("""
        digraph tree_example {
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.2,0.1"]
            edge [fontname="Helvetica", fontsize=9, color="#888888", penwidth=1.5]
            q1 [label="Contract = \\nMonth-to-month?", fillcolor="#e0e7ff", color="#6366f1"]
            q2 [label="tenure < 12\\nmonths?", fillcolor="#e0e7ff", color="#6366f1"]
            q3 [label="MonthlyCharges\\n> $70?", fillcolor="#e0e7ff", color="#6366f1"]
            churn1 [label="CHURN\\n(85% probability)", fillcolor="#fee2e2", color="#ef4444"]
            retain1 [label="RETAIN\\n(70% probability)", fillcolor="#d1fae5", color="#10b981"]
            churn2 [label="CHURN\\n(60% probability)", fillcolor="#fee2e2", color="#ef4444"]
            retain2 [label="RETAIN\\n(90% probability)", fillcolor="#d1fae5", color="#10b981"]
            q1 -> q2 [label=" Yes"]
            q1 -> q3 [label=" No"]
            q2 -> churn1 [label=" Yes"]
            q2 -> retain1 [label=" No"]
            q3 -> churn2 [label=" Yes"]
            q3 -> retain2 [label=" No"]
        }
    """)
    st.markdown(
        """
        **How it works:**

        1. Imagine asking a series of yes/no questions: "Is the contract month-to-month?"
           → "Is tenure less than 12 months?" → "Are charges above $70?"
           That's a **decision tree** — it keeps splitting until it reaches an answer.
        2. A single tree is easy to overfit (it memorizes the training data too closely).
           So Random Forest builds **200 trees**, each trained on a **random sample**
           of the data and considering a **random subset** of features at each split.
        3. For a new customer, all 200 trees make their prediction and we take
           the **majority vote**.

        **Why "Random"?** Each tree only sees a random portion of the data and features.
        This diversity prevents the forest from over-relying on any single pattern.

        **Why Label Encoding is fine:** Trees split on thresholds (e.g., "is Contract < 1?").
        They never multiply the encoded number — so the integer values don't introduce false relationships.

        **Strengths:** Handles non-linear patterns, resistant to overfitting, works with Label Encoding.

        **Weaknesses:** Slower than Logistic Regression, less interpretable (200 trees are hard to inspect by hand).
        """
    )
# Worked numerical example for Random Forest, collapsed by default and placed
# inside the "How The Algorithms Work" tab.
with tab_how.expander("Step-by-step numerical walkthrough — Random Forest", expanded=False):
    st.markdown(
        """
        **Same customer:** tenure=3, MonthlyCharges=$85, Contract=Month-to-month,
        Internet=Fiber optic, Partner=No.

        ---

        **Step 1 — Each tree sees different data**

        Random Forest creates 200 trees. Each tree gets:

        - A **bootstrap sample** — as many random draws from the training customers as there
          are customers, *with replacement*, so each tree sees only about 63% of the distinct
          customers (some of them more than once)
        - A **random subset of features** at each split (e.g., 4 out of 19 features)

        | Tree | Training Customers | Features Sampled at the Root Split |
        |---|---|---|
        | Tree 1 | Customers #12, #45, #45, #78, #102, ... | tenure, Contract, Partner, TechSupport |
        | Tree 2 | Customers #3, #22, #56, #56, #89, ... | MonthlyCharges, InternetService, gender, SeniorCitizen |
        | Tree 3 | Customers #7, #33, #41, #67, #67, ... | tenure, PaymentMethod, MonthlyCharges, OnlineSecurity |
        | ... | ... | ... |

        Notice: Customer #45 appears twice in Tree 1 (bootstrap sampling), and each
        tree considers different features.

        ---

        **Step 2 — Each tree grows by finding the best splits**

        **Tree 1, root split** (candidates: tenure, Contract, Partner, TechSupport):

        The algorithm tries every possible split and picks the one that best
        separates churners from non-churners (measured by **Gini impurity**):
        """
    )
    st.latex(r"\text{Gini} = 1 - p_{\text{churn}}^2 - p_{\text{retain}}^2")
    st.markdown(
        """
        | Candidate Split | Left Group (Churn%) | Right Group (Churn%) | Gini Improvement |
        |---|---|---|---|
        | Contract < 1 (Month-to-month) | 42% churn | 12% churn | **0.18** ← best |
        | tenure < 12 | 38% churn | 20% churn | 0.14 |
        | Partner = 0 | 30% churn | 24% churn | 0.03 |

        Contract split wins → becomes the first question.

        Then each branch splits again (drawing a fresh random subset of features each time),
        creating deeper questions. This continues until leaves are pure or a depth limit is reached.

        ---

        **Step 3 — Each tree predicts independently**

        For our customer (Month-to-month, tenure=3, Fiber optic, no Partner):

        | Tree | Path through the tree | Prediction |
        |---|---|---|
        | Tree 1 | Contract=M-t-m → tenure<12 → **CHURN** | Churn (92%) |
        | Tree 2 | Charges>$70 → Fiber optic → **CHURN** | Churn (78%) |
        | Tree 3 | tenure<6 → Charges>$60 → **CHURN** | Churn (85%) |
        | Tree 4 | Contract=M-t-m → No TechSupport → **CHURN** | Churn (88%) |
        | Tree 5 | Fiber optic → tenure<24 → **RETAIN** | Retain (55%) |
        | ... | ... | ... |

        ---

        **Step 4 — Majority vote**

        Out of 200 trees:

        - 172 trees predict **Churn**
        - 28 trees predict **Retain**

        `Probability = 172/200 = 86%` → Predict **Churn**

        (Libraries such as scikit-learn actually average each tree's predicted probability
        instead of counting hard votes, but the intuition is the same.)

        Note: Tree 5 predicted Retain — that's fine. The diversity is intentional.
        If every tree agreed perfectly, there would be no benefit from having 200 of them.

        ---

        **Why Random Forest handles non-linear patterns:**

        A single tree can capture "Contract=Month-to-month AND tenure<12 → high churn"
        without needing to encode this interaction explicitly.

        The forest combines 200 different views of the data, each capturing
        different interaction patterns. The majority vote smooths out individual
        tree errors — this is why forests are resistant to overfitting.
        """
    )
# ── XGBoost ──────────────────────────────────────────────────────────────────
# Sequential-boosting diagram and prose explanation, rendered in the
# "How The Algorithms Work" tab.
with tab_how:
    st.markdown("---")
    st.markdown("### 4. XGBoost (Extreme Gradient Boosting)")
    st.markdown("*A go-to algorithm for tabular data — behind many winning Kaggle solutions.*")
    st.graphviz_chart("""
        digraph xgb {
            rankdir=TB
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
            edge [color="#888888", penwidth=1.5]
            data [label="Training Data", fillcolor="#dbeafe", color="#3b82f6"]
            t1 [label="Tree 1\\nLearns the main pattern\\n(e.g., contract type)", fillcolor="#d1fae5", color="#10b981"]
            e1 [label="Errors from Tree 1\\n(customers it got wrong)", fillcolor="#fee2e2", color="#ef4444"]
            t2 [label="Tree 2\\nFocuses on Tree 1's mistakes", fillcolor="#d1fae5", color="#10b981"]
            e2 [label="Remaining errors", fillcolor="#fee2e2", color="#ef4444"]
            t3 [label="Tree 3\\nFocuses on remaining mistakes", fillcolor="#d1fae5", color="#10b981"]
            dots [label="... (200 trees total, each fixing\\nthe previous tree's errors)", shape=plaintext, fontname="Helvetica"]
            final [label="Final Prediction\\nSum of all 200 trees\\n(each weighted by learning rate)", fillcolor="#fef3c7", color="#f59e0b"]
            data -> t1 -> e1 -> t2 -> e2 -> t3 -> dots -> final
        }
    """)
    st.markdown(
        """
        **How it works:**

        1. **Tree 1** tries to predict churn for all customers. It gets some right, some wrong.
        2. **Tree 2** doesn't start from scratch — it specifically focuses on the customers
           that Tree 1 got **wrong**. It tries to correct those mistakes.
        3. **Tree 3** focuses on the remaining mistakes from Tree 1 + Tree 2 combined.
        4. This continues for 200 rounds. Each new tree is a specialist in fixing
           what all previous trees couldn't get right.
        5. The final prediction is the **weighted sum** of all 200 trees.

        **The key difference from Random Forest:**

        - Random Forest: 200 trees trained **independently** (in parallel), then vote.
        - XGBoost: 200 trees trained **sequentially**, each one learning from the previous one's errors.

        **Why "Gradient Boosting"?** "Gradient" refers to the mathematical technique used to
        determine *how* each new tree should focus on errors. It's the same gradient descent
        concept used in deep learning.

        **Strengths:** Usually the most accurate model, handles complex non-linear patterns,
        built-in regularization prevents overfitting.

        **Weaknesses:** Slower to train, harder to interpret, more hyperparameters to tune.
        """
    )
# Worked numerical example for XGBoost, collapsed by default and placed inside
# the "How The Algorithms Work" tab.
with tab_how.expander("Step-by-step numerical walkthrough — XGBoost", expanded=False):
    st.markdown(
        """
        **Same customer:** tenure=3, MonthlyCharges=$85, Contract=Month-to-month,
        Internet=Fiber optic, Partner=No.

        ---

        **Step 1 — Start with a base prediction**

        Before any trees, XGBoost starts with the global average:

        `base prediction = overall churn rate = 0.265` (26.5%)

        Convert to log-odds:
        """
    )
    st.latex(r"\text{log-odds} = \ln\left(\frac{0.265}{1 - 0.265}\right) = \ln(0.361) = -1.02")
    st.markdown(
        """
        Every customer starts at -1.02 (26.5% churn probability).

        ---

        **Step 2 — Tree 1 learns from the errors**

        For each customer, calculate the **residual** (how wrong the base prediction is):

        | Customer | Actual | Current Prediction | Residual |
        |---|---|---|---|
        | #1 | Churned (1) | 0.265 | +0.735 (prediction too low) |
        | #2 | Retained (0) | 0.265 | -0.265 (prediction too high) |
        | #3 | Churned (1) | 0.265 | +0.735 |
        | Our customer | Churned (1) | 0.265 | +0.735 |

        Tree 1 tries to predict these residuals (not the original labels).

        Tree 1 output for our customer: **+0.45** (it learned a partial correction).

        ---

        **Step 3 — Update prediction with learning rate**

        XGBoost scales each tree's output by a **learning rate** (η) to prevent overshooting.
        This walkthrough uses η = 0.1 (XGBoost's own default is 0.3):

        `new log-odds = -1.02 + (0.1 × 0.45) = -1.02 + 0.045 = -0.975`
        """
    )
    st.latex(r"P = \frac{1}{1 + e^{-(-0.975)}} = \frac{1}{1 + 2.65} = 0.274")
    st.markdown(
        """
        Probability moved from 26.5% to 27.4%. A small step in the right direction.

        ---

        **Step 4 — Tree 2 focuses on remaining errors**

        New residuals (using updated predictions):

        | Customer | Actual | Updated Prediction | New Residual |
        |---|---|---|---|
        | Our customer | 1 | 0.274 | +0.726 (still too low) |

        Tree 2 learns these new residuals. Output for our customer: **+0.52**

        `log-odds = -0.975 + (0.1 × 0.52) = -0.975 + 0.052 = -0.923`

        `P = 28.4%`

        ---

        **Step 5 — Continue for 200 trees**

        | After Tree | Log-odds | Churn Probability |
        |---|---|---|
        | Base (no trees) | -1.02 | 26.5% |
        | Tree 1 | -0.975 | 27.4% |
        | Tree 2 | -0.923 | 28.4% |
        | Tree 10 | -0.42 | 39.7% |
        | Tree 50 | 0.85 | 70.1% |
        | Tree 100 | 1.72 | 84.8% |
        | Tree 200 | 2.95 | 95.0% |

        Each tree adds a small correction. After 200 trees, the probability
        climbed from 26.5% to 95.0%.

        ---

        **The final prediction formula:**
        """
    )
    st.latex(r"\text{prediction} = \text{base} + \eta \cdot f_1(x) + \eta \cdot f_2(x) + \ldots + \eta \cdot f_{200}(x)")
    st.markdown(
        """
        Where each f(x) is a small decision tree focused on the remaining errors.

        ---

        **Key differences from Random Forest:**

        | | Random Forest | XGBoost |
        |---|---|---|
        | **Trees learn** | Independently (in parallel) | Sequentially (each from previous errors) |
        | **Each tree predicts** | The original label (Churn/Retain) | The **residual error** from all previous trees |
        | **Learning rate** | Not applicable | Controls step size (η=0.1 is typical) |
        | **Why it helps** | Diversity from random sampling | Each tree is a specialist in remaining errors |
        | **Risk** | Hard to overfit | Can overfit if too many trees or learning rate too high |

        The learning rate is crucial: without it (η=1.0), early trees would over-correct
        and the model would overfit. Small steps (η=0.1) force the model to build
        gradually, which produces better generalization.
        """
    )
# ── SHAP ─────────────────────────────────────────────────────────────────────
# Waterfall-style diagram and prose explanation of SHAP values, rendered in the
# "How The Algorithms Work" tab. Numbered 5 because XGBoost is section 4.
with tab_how:
    st.markdown("---")
    st.markdown("### 5. SHAP — How We Explain Predictions")
    st.markdown("*Making the black box transparent.*")
    st.graphviz_chart("""
        digraph shap {
            rankdir=LR
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
            edge [color="#888888", penwidth=1.5]
            base [label="Base Rate\\n26.5% churn\\n(average customer)", fillcolor="#e0e7ff", color="#6366f1"]
            f1 [label="Contract =\\nMonth-to-month\\n+18%", fillcolor="#fee2e2", color="#ef4444"]
            f2 [label="tenure =\\n2 months\\n+12%", fillcolor="#fee2e2", color="#ef4444"]
            f3 [label="TechSupport =\\nNo\\n+5%", fillcolor="#fee2e2", color="#ef4444"]
            f4 [label="TotalCharges =\\n$150 (low)\\n+3%", fillcolor="#fee2e2", color="#ef4444"]
            f5 [label="Partner = Yes\\n-4%", fillcolor="#d1fae5", color="#10b981"]
            pred [label="Final Prediction\\n60.5% churn", fillcolor="#fef3c7", color="#f59e0b"]
            base -> f1 -> f2 -> f3 -> f4 -> f5 -> pred
        }
    """)
    st.markdown(
        """
        **How SHAP works:**

        Every prediction starts from the **base rate** — the overall churn rate in the data (~26.5%).
        Then SHAP shows how each feature **pushes** that prediction up or down:

        - Month-to-month contract → pushes probability **up** (toward churn)
        - Low tenure → pushes probability **up** (new customer, high risk)
        - Having a partner → pushes probability **down** (slightly protective)
        - Each feature gets a + or - contribution, and they all add up to the final prediction.

        **Why this matters for business:** If the model predicts an 80% churn probability,
        SHAP tells you *why* — "It's mainly because they're on a month-to-month contract
        and only been with us for 2 months." That's actionable: offer them a yearly contract
        with a discount.

        **The name "SHAP"** stands for SHapley Additive exPlanations, based on Shapley values
        from game theory — a mathematically rigorous way to fairly distribute credit among features.
        """
    )
# Collapsible worked example of how Shapley values are computed and why they add
# up to the prediction. Every "---" rule is surrounded by blank lines. Otherwise
# markdown reads a paragraph immediately followed by "---" as a setext heading.
with st.expander("Step-by-step numerical walkthrough — SHAP values", expanded=False):
    st.markdown(
        """
SHAP uses a concept from game theory: **Shapley values**. The idea is to
fairly distribute credit among "players" (features) for the "payout" (prediction).

**The question:** XGBoost predicts 95% churn for our customer. But *which features
caused this?* And by how much?

---

**Step 1 — Start from the base value**

The base value is the average model output across all training customers:
`base value = 26.5%` (overall churn rate)

---

**Step 2 — Measure each feature's marginal contribution**

For each feature, SHAP asks: "If I remove this feature, how much does the
prediction change?" But it's more sophisticated — it checks this across
**every possible combination** of other features.

**Example for "Contract = Month-to-month":**

| Features included (besides Contract) | Prediction with Contract | Prediction without | Contribution of Contract |
|---|---|---|---|
| None | 45% | 26.5% | +18.5% |
| tenure only | 52% | 35% | +17% |
| tenure + MonthlyCharges | 68% | 48% | +20% |
| tenure + MonthlyCharges + Internet | 82% | 60% | +22% |
| All features | 95% | 73% | +22% |

The SHAP value for Contract = **average of all marginal contributions**
= (18.5 + 17 + 20 + 22 + 22) / 5 = **+19.9%**

*(Simplified for illustration: the table shows five representative combinations.
The exact Shapley value is a weighted average over every subset of the other
features, where each subset's weight depends on its size.)*

---

**Step 3 — Do this for every feature**

| Feature | SHAP Value | Direction |
|---|---|---|
| Contract = Month-to-month | +19.9% | Pushes toward churn |
| tenure = 3 months | +15.2% | Pushes toward churn |
| InternetService = Fiber optic | +11.8% | Pushes toward churn |
| MonthlyCharges = $85 | +8.4% | Pushes toward churn |
| TechSupport = No | +6.1% | Pushes toward churn |
| OnlineSecurity = No | +4.2% | Pushes toward churn |
| Partner = No | +2.1% | Pushes toward churn |
| PaymentMethod = Electronic check | +1.3% | Pushes toward churn |
| Other features | -0.5% | Small protective effects |

---

**Step 4 — Everything adds up to the prediction**

`26.5% (base) + 19.9% + 15.2% + 11.8% + 8.4% + 6.1% + 4.2% + 2.1% + 1.3% - 0.5% = 95.0%`

This is guaranteed by the Shapley value properties — the contributions
always sum exactly to the difference between the base value and the prediction.
(For XGBoost, SHAP's tree explainer works in log-odds by default, so this exact
additivity holds in log-odds space; the percentages here are illustrative.)

---

**Why Shapley values are "fair":**

Shapley values are the *only* method that satisfies all four fairness properties:

1. **Efficiency** — contributions sum to the prediction minus the base value
2. **Symmetry** — features with identical effects get identical values
3. **Null player** — irrelevant features get zero contribution
4. **Linearity** — the explanation of a sum of models is the sum of their explanations

This is why SHAP is preferred over simpler methods like "feature importance" —
it's mathematically guaranteed to distribute credit fairly.
"""
    )
# ── Comparison Table ─────────────────────────────────────────────────────────
# Side-by-side summary of the tree/linear models: how each learns, which
# categorical encoding it is fed, and its speed/accuracy/interpretability trade-offs.
st.markdown("---")
st.markdown("### Quick Comparison")
comparison_table = """
| | Logistic Regression | Random Forest | XGBoost |
|---|---|---|---|
| **How it learns** | Finds the best weights for a linear equation | Builds many independent trees and averages their votes | Builds trees sequentially, each correcting the last |
| **Encoding** | One-Hot (needs binary columns) | Label (integers are fine) | Label (integers are fine) |
| **Speed** | Very fast | Moderate | Slower |
| **Accuracy** | Good baseline | Very good | Usually best |
| **Interpretability** | High (weights = feature importance) | Medium (feature importance available) | Medium (needs SHAP for full explanation) |
| **Best for** | Simple, linear relationships | Non-linear patterns with moderate data | Complex patterns, competitions, production systems |
"""
st.markdown(comparison_table)
# ── Why Each Model ───────────────────────────────────────────────────────────
# Per-model rationale (why it is included, industry use, benefit, limitation,
# when to pick it). Also covers how the models complement each other and why
# Logistic Regression can compete on this dataset.
# Blank lines around tables and rules are required. Without them, GFM appends a
# following paragraph to the table as extra rows. A paragraph directly followed
# by "---" would render as a heading.
st.markdown("---")
st.markdown("### Why We Chose Each Model")
st.markdown(
    """
#### Logistic Regression — The Baseline

| | |
|---|---|
| **Why we included it** | Every ML project needs a simple baseline to compare against. If a complex model can't beat Logistic Regression, the complexity isn't justified. |
| **Where it's used in industry** | Credit scoring (banks are required to use interpretable models), medical diagnosis screening, spam detection, ad click prediction at scale. |
| **Key benefit** | Full transparency — every feature gets a weight you can show to stakeholders. "Month-to-month contract increases the odds of churn by X%." |
| **Limitation** | Assumes each feature contributes independently. Can't learn "fiber optic is only risky when tenure is short" without manually creating that interaction feature. |
| **When to pick this** | When you need to explain *why* to regulators or non-technical stakeholders, or when you need a fast model that runs on millions of rows per second. |

---

#### Naive Bayes — The Probabilistic Baseline

| | |
|---|---|
| **Why we included it** | Provides a fundamentally different approach based on probability theory. Makes the independence assumption explicit, which helps identify when features actually interact. |
| **Where it's used in industry** | Email spam filtering (Gmail), document classification, real-time sentiment analysis, medical diagnosis screening. |
| **Key benefit** | Extremely fast training and prediction. Works well with small datasets. Outputs class probabilities directly, although they are often poorly calibrated. |
| **Limitation** | The independence assumption is often violated (e.g., contract type and monthly charges are correlated). Can't capture feature interactions. |
| **When to pick this** | When speed matters more than accuracy, or as a baseline to test if feature interactions are important (if NB matches complex models, interactions don't matter). |

---

#### Random Forest — The Reliable Workhorse

| | |
|---|---|
| **Why we included it** | It captures non-linear patterns and feature interactions that Logistic Regression and Naive Bayes miss, without needing manual feature engineering. |
| **Where it's used in industry** | Fraud detection (PayPal), customer segmentation, manufacturing defect prediction, insurance risk assessment. |
| **Key benefit** | Robust to outliers and noise. Rarely overfits. Handles mixed feature types without scaling or encoding concerns. |
| **Limitation** | Slower inference than Logistic Regression or Naive Bayes. The 200-tree ensemble is harder to explain to non-technical audiences. |
| **When to pick this** | When you need reliable accuracy with minimal tuning. It's the "safe choice" — almost always performs well without surprises. |

---

#### XGBoost — The Performance Leader

| | |
|---|---|
| **Why we included it** | It's consistently the top-performing algorithm for structured/tabular data like our customer dataset. |
| **Where it's used in industry** | A fixture of winning Kaggle solutions on tabular data since 2015. Used at Airbnb (search ranking), Uber (ETA prediction), major banks (credit risk). |
| **Key benefit** | Sequential error correction means it learns from its own mistakes. Built-in regularization prevents overfitting. |
| **Limitation** | More hyperparameters to tune. Slower to train than Random Forest. Requires SHAP for interpretability. |
| **When to pick this** | When predictive accuracy is the priority and you have time to tune. For production churn models, XGBoost is typically the default choice. |

---

#### SHAP — The Explainability Layer

| | |
|---|---|
| **Why we included it** | A prediction without an explanation has no business value. SHAP makes any model interpretable. |
| **Where it's used in industry** | Helps meet explainability expectations for automated decisions under EU rules (e.g., GDPR). Widely used in banking to explain credit decisions. Standard at tech companies for model debugging. |
| **Key benefit** | Works with *any* model. Gives both global importance (which features matter overall) and local explanations (why *this* customer was flagged). |
| **When to use** | Always. SHAP should be part of every production ML system. |

---

#### How They Complement Each Other

| Role | Model | Why |
|---|---|---|
| **Linear baseline** | Logistic Regression | Sets the floor — any model must beat this to justify its complexity |
| **Probabilistic baseline** | Naive Bayes | Tests if feature independence assumption holds. If NB performs poorly, we know interactions matter. |
| **Production model** | XGBoost | Highest accuracy for deployment |
| **Backup / ensemble** | Random Forest | If XGBoost overfits on new data, RF is a stable alternative |
| **Explainability** | SHAP on XGBoost | Turns the best model into something stakeholders can act on |
| **Live updates** | SGDClassifier (Live page) | The only model supporting `partial_fit` for incremental learning |

In a real deployment, you would use XGBoost as the primary model with SHAP
for explanations, keep Random Forest as a monitoring baseline, compare against
Naive Bayes to validate that feature interactions are being captured, and use the
SGDClassifier pattern for real-time drift adaptation.

---

#### Why Logistic Regression Can Match or Beat Complex Models on This Dataset

If Logistic Regression performs as well as XGBoost here, that's an important
finding — not a flaw.

| Factor | Why It Helps LR |
|---|---|
| **Small dataset (~5,600 training rows)** | XGBoost has far more parameters (200 trees × many splits) than LR (a few dozen weights). With limited data, complex models risk overfitting. LR's simplicity is an advantage. |
| **Linear churn signal** | Contract type, tenure, and charges have straightforward relationships with churn. These are essentially linear effects — exactly what LR is designed for. |
| **One-hot encoding** | LR receives one-hot encoded features (each category as its own binary column). This avoids false ordinal relationships that label encoding introduces for linear models. |
| **Strong regularization** | LR's built-in L2 regularization prevents any single feature from dominating, keeping the model stable. |

**What this means in practice:** A simpler, fully interpretable model achieving
top accuracy is a **win**. Every prediction can be explained to the retention team,
the model is faster to deploy, and regulatory compliance is easier.

**When would XGBoost pull ahead?** With a much larger dataset (100K+ customers)
and richer features (browsing behavior, support call logs, time-series usage patterns),
XGBoost would exploit non-linear interactions that LR can't capture — like
"fiber optic is only risky for short-tenure customers without tech support."
"""
)
# ── Tab 0: Predict and Compare ───────────────────────────────────────────────
# Live demo. Take the first 3 churned and first 2 retained customers from the
# held-out test set. On button press, run every model in `all_models` on them
# and show each model's verdict next to the actual outcome.
with tab_predict:
    st.subheader("Predict and Compare — Live Demo")
    st.markdown(
        "Below are **5 real customers** from the test set (data the models have never seen during training). "
        "Click **Run Predictions** to see what each model thinks — then compare against what actually happened."
    )
    raw_df_pred = load_raw_data()
    # y_test labels: 1 = churned, 0 = retained (per the labels rendered below).
    churned_idx = y_test[y_test == 1].index[:3]
    retained_idx = y_test[y_test == 0].index[:2]
    demo_idx = churned_idx.tolist() + retained_idx.tolist()
    display_columns = [
        "customerID", "gender", "SeniorCitizen", "tenure", "Contract",
        "InternetService", "MonthlyCharges", "TotalCharges",
    ]
    demo_display = raw_df_pred.loc[demo_idx, display_columns].copy()
    demo_display.index = range(1, len(demo_display) + 1)
    demo_display.index.name = "#"
    st.markdown("#### Customer Profiles")
    st.dataframe(demo_display, use_container_width=True)

    if st.button("Run Predictions", type="primary", use_container_width=True):
        st.markdown("---")
        st.markdown("#### Prediction Results")
        model_names = list(all_models.keys())

        # A single grid layout is shared by the header and every result row.
        # Columns are: #, customer ID, one column per model, actual outcome, status icon.
        grid_widths = [0.5, 1.5] + [2] * len(model_names) + [1.2, 0.5]
        actual_col = len(model_names) + 2
        status_col = len(model_names) + 3
        status_icons = {"all_correct": "✅", "all_wrong": "❌", "mixed": "⚠️"}

        header_cols = st.columns(grid_widths)
        header_cols[0].markdown("**#**")
        header_cols[1].markdown("**Customer ID**")
        for j, mn in enumerate(model_names):
            header_cols[j + 2].markdown(f"**{mn}**")
        header_cols[actual_col].markdown("**Actual**")
        header_cols[status_col].markdown("**✓**")
        st.markdown("---")

        # Score every demo customer with every model and record per-model correctness.
        results_rows = []
        for pos, idx in enumerate(demo_idx, 1):
            actual_val = y_test.loc[idx]
            row_result = {
                "#": pos,
                "Customer ID": raw_df_pred.loc[idx, "customerID"],
            }
            correct_flags = []
            for model_name in model_names:
                model_obj = all_models[model_name]
                # Each model has its own encoded test matrix (one-hot vs label encoding).
                row_enc = model_test_data[model_name].loc[[idx]]
                pred = int(model_obj.predict(row_enc)[0])
                proba = float(model_obj.predict_proba(row_enc)[0][1])
                pred_label = "Churned" if pred == 1 else "Retained"
                correct = bool(pred == actual_val)
                correct_flags.append(correct)
                row_result[model_name] = f"{pred_label} ({proba:.0%})"
                row_result[f"{model_name}_correct"] = correct
            row_result["Actual"] = "Churned" if actual_val == 1 else "Retained"
            if all(correct_flags):
                row_result["_status"] = "all_correct"
            elif not any(correct_flags):
                row_result["_status"] = "all_wrong"
            else:
                row_result["_status"] = "mixed"
            results_rows.append(row_result)

        # Render one grid row per customer.
        for row in results_rows:
            cols = st.columns(grid_widths)
            cols[0].markdown(f"**{row['#']}**")
            cols[1].markdown(f"`{row['Customer ID']}`")
            for j, mn in enumerate(model_names):
                correct = row[f"{mn}_correct"]
                dot = "🟢" if correct else "🔴"
                mark = "✓" if correct else "✗"
                cols[j + 2].markdown(f"{dot} {row[mn]} {mark}")
            cols[actual_col].markdown(f"**{row['Actual']}**")
            cols[status_col].markdown(status_icons[row["_status"]])
        st.markdown("---")

        # Summary counts of model agreement across the demo customers.
        n_total = len(results_rows)
        n_all_correct = sum(r["_status"] == "all_correct" for r in results_rows)
        n_all_wrong = sum(r["_status"] == "all_wrong" for r in results_rows)
        n_mixed = n_total - n_all_correct - n_all_wrong
        rc1, rc2, rc3 = st.columns(3)
        rc1.metric("All Models Correct", f"{n_all_correct} / {n_total}", delta="✅")
        rc2.metric("Mixed Results", f"{n_mixed} / {n_total}", delta="⚠️" if n_mixed > 0 else None)
        rc3.metric(
            "All Models Wrong",
            f"{n_all_wrong} / {n_total}",
            delta="❌" if n_all_wrong > 0 else None,
            delta_color="inverse",
        )
        st.markdown(
            "**Legend:** 🟢 = correct prediction, 🔴 = wrong prediction. "
            "Percentage shown is the model's estimated churn probability."
        )
# ── Tab 1: Model Comparison ──────────────────────────────────────────────────
# Metrics table (best value per column highlighted), ROC curves for every
# model, and one confusion matrix per model.
with tab_compare:
    st.subheader("Performance Metrics")
    metrics_df = pd.DataFrame(metrics).T.round(3)
    best_model = metrics_df["AUC"].idxmax()
    st.info(f"Best model by AUC: **{best_model}** ({metrics_df.loc[best_model, 'AUC']:.3f})")

    # Categorical encoding each model was trained on. Align to the metrics rows
    # so the table never gains stray rows. Any model missing from this mapping
    # shows "—" instead of an empty NaN cell.
    encoding_col = (
        pd.Series(
            {
                "Logistic Regression": "One-Hot",
                "Random Forest": "Label",
                "XGBoost": "Label",
            },
            name="Encoding",
        )
        .reindex(metrics_df.index)
        .fillna("—")
    )
    display_metrics = pd.concat([encoding_col, metrics_df], axis=1)
    st.dataframe(
        display_metrics.style.highlight_max(axis=0, subset=metrics_df.columns, color="#c6efce"),
        use_container_width=True,
    )

    st.markdown("---")
    st.subheader("ROC Curves")
    roc_entries = [(name, all_models[name], model_test_data[name]) for name in all_models]
    st.plotly_chart(plot_roc_curves(roc_entries, y_test), use_container_width=True)

    st.markdown("---")
    st.subheader("Confusion Matrices")
    # One column per model. A hard-coded count breaks as soon as a model is
    # added (IndexError) or removed (empty columns).
    cm_cols = st.columns(len(all_models))
    for cm_col, name in zip(cm_cols, all_models):
        with cm_col:
            y_pred = all_models[name].predict(model_test_data[name])
            st.plotly_chart(
                plot_confusion_matrix(y_test, y_pred, title=name),
                use_container_width=True,
            )
# ── Tab 2: Global SHAP ───────────────────────────────────────────────────────
# Global SHAP views (bar + beeswarm) for the XGBoost model. This block also
# defines `xgb_model` and `explainer`, which the individual-explanation and
# what-if tabs below reuse.
with tab_shap_global:
    st.subheader("Global Feature Importance — XGBoost")
    st.markdown("SHAP values show how much each feature pushes the prediction toward or away from churn.")
    xgb_model = all_models["XGBoost"]
    explainer, shap_values = get_shap_explainer(xgb_model, X_train)

    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    shap.plots.bar(shap_values, max_display=15, show=False, ax=ax_bar)
    st.pyplot(fig_bar)
    # pyplot keeps every figure alive until it is explicitly closed. Without this,
    # each Streamlit rerun would leak another figure.
    plt.close(fig_bar)

    st.markdown("---")
    st.markdown("**Beeswarm Plot** — Each dot is a customer. Color = feature value (red = high, blue = low).")
    # beeswarm draws on pyplot's current figure, so create it first, then
    # render and close whichever figure is current afterwards.
    plt.subplots(figsize=(10, 8))
    shap.plots.beeswarm(shap_values, max_display=15, show=False)
    fig_bee = plt.gcf()
    st.pyplot(fig_bee)
    plt.close(fig_bee)
# ── Tab 3: Individual Explanations ───────────────────────────────────────────
# Pick a customer and locate their encoded feature row. The test split is
# searched first, then the train split, then the full encoded dataset. XGBoost's
# prediction is then explained with a SHAP waterfall. Uses `xgb_model` and
# `explainer` defined in the global-SHAP tab.
INDIVIDUAL_MAX_CUSTOMERS = 200  # dropdown size cap, keeps the selectbox responsive

with tab_shap_individual:
    st.subheader("Explain a Single Customer's Prediction")
    raw_df = load_raw_data()
    customer_ids = raw_df["customerID"].values
    selected_id = st.selectbox("Select Customer ID", customer_ids[:INDIVIDUAL_MAX_CUSTOMERS])
    # Filter once; reused for both the index lookup and the raw-details table.
    selected_raw = raw_df[raw_df["customerID"] == selected_id]
    idx_in_raw = selected_raw.index[0]
    enc_df, _ = get_encoded_data()
    if idx_in_raw in X_test.index:
        row = X_test.loc[[idx_in_raw]]
        actual = y_test.loc[idx_in_raw]
    elif idx_in_raw in X_train.index:
        row = X_train.loc[[idx_in_raw]]
        actual = y_train.loc[idx_in_raw]
    else:
        row = enc_df.loc[[idx_in_raw], feature_cols]
        actual = enc_df.loc[idx_in_raw, "Churn"]
    proba = xgb_model.predict_proba(row)[0][1]
    c1, c2 = st.columns(2)
    c1.metric("Predicted Churn Probability", f"{proba:.1%}")
    c2.metric("Actual Outcome", "Churned" if actual == 1 else "Retained")
    st.markdown("**Customer Details (raw values):**")
    st.dataframe(selected_raw.T, use_container_width=True)

    st.markdown("---")
    st.markdown("**SHAP Waterfall — Why this prediction?**")
    sv = get_shap_single(explainer, row)
    # waterfall draws on pyplot's current figure; close it after rendering so
    # figures do not accumulate across reruns.
    plt.subplots(figsize=(10, 6))
    shap.plots.waterfall(sv[0], max_display=12, show=False)
    fig_wf = plt.gcf()
    st.pyplot(fig_wf)
    plt.close(fig_wf)
# ── Tab 4: What-If Predictor ─────────────────────────────────────────────────
# Build one synthetic customer from widget inputs and label-encode the
# categorical fields with the encoders from `get_encoded_data()`. The row is
# then scored with XGBoost, shown as a risk band plus gauge, and explained with
# a SHAP waterfall. Uses `xgb_model`, `explainer` and `feature_cols` defined earlier.
with tab_whatif:
    st.subheader("What-If Predictor")
    st.markdown("Adjust customer features and see how churn probability changes in real time.")

    wi_col1, wi_col2, wi_col3 = st.columns(3)
    with wi_col1:
        wi_gender = st.selectbox("Gender", ["Female", "Male"], key="wi_gender")
        wi_senior = st.selectbox("Senior Citizen", [0, 1], key="wi_senior")
        wi_partner = st.selectbox("Partner", ["Yes", "No"], key="wi_partner")
        wi_dependents = st.selectbox("Dependents", ["Yes", "No"], key="wi_dep")
        wi_tenure = st.slider("Tenure (months)", 0, 72, 12, key="wi_tenure")
    with wi_col2:
        wi_phone = st.selectbox("Phone Service", ["Yes", "No"], key="wi_phone")
        wi_multi = st.selectbox("Multiple Lines", ["No", "Yes", "No phone service"], key="wi_multi")
        wi_internet = st.selectbox("Internet Service", ["DSL", "Fiber optic", "No"], key="wi_inet")
        wi_security = st.selectbox("Online Security", ["Yes", "No", "No internet service"], key="wi_sec")
        wi_backup = st.selectbox("Online Backup", ["Yes", "No", "No internet service"], key="wi_bak")
    with wi_col3:
        wi_protection = st.selectbox("Device Protection", ["Yes", "No", "No internet service"], key="wi_prot")
        wi_support = st.selectbox("Tech Support", ["Yes", "No", "No internet service"], key="wi_sup")
        wi_tv = st.selectbox("Streaming TV", ["Yes", "No", "No internet service"], key="wi_tv")
        wi_movies = st.selectbox("Streaming Movies", ["Yes", "No", "No internet service"], key="wi_mov")
        wi_contract = st.selectbox("Contract", ["Month-to-month", "One year", "Two year"], key="wi_con")

    wi_col4, wi_col5 = st.columns(2)
    with wi_col4:
        wi_paperless = st.selectbox("Paperless Billing", ["Yes", "No"], key="wi_paper")
        wi_payment = st.selectbox(
            "Payment Method",
            ["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"],
            key="wi_pay",
        )
    with wi_col5:
        wi_monthly = st.slider("Monthly Charges ($)", 18.0, 120.0, 70.0, step=0.5, key="wi_monthly")
        wi_total = st.slider("Total Charges ($)", 18.0, 9000.0, 1500.0, step=10.0, key="wi_total")

    # Categorical inputs get label-encoded below; numeric inputs pass through as-is.
    input_dict = {
        "gender": wi_gender, "Partner": wi_partner, "Dependents": wi_dependents,
        "PhoneService": wi_phone, "MultipleLines": wi_multi, "InternetService": wi_internet,
        "OnlineSecurity": wi_security, "OnlineBackup": wi_backup,
        "DeviceProtection": wi_protection, "TechSupport": wi_support,
        "StreamingTV": wi_tv, "StreamingMovies": wi_movies, "Contract": wi_contract,
        "PaperlessBilling": wi_paperless, "PaymentMethod": wi_payment,
    }
    numeric_dict = {
        "SeniorCitizen": wi_senior, "tenure": wi_tenure,
        "MonthlyCharges": wi_monthly, "TotalCharges": wi_total,
    }

    _, enc_map = get_encoded_data()
    encoded_input = {}
    unseen_values = []
    for col, val in input_dict.items():
        le = enc_map[col]
        if val in le.classes_:
            encoded_input[col] = le.transform([val])[0]
        else:
            # Keep the best-effort fallback to code 0, but record it. Code 0 is also
            # a real category, so a silent substitution would mislead the user.
            encoded_input[col] = 0
            unseen_values.append(f"{col}={val!r}")
    encoded_input.update(numeric_dict)
    input_row = pd.DataFrame([encoded_input])[feature_cols]
    if unseen_values:
        st.warning(
            "These inputs were not seen during training and were encoded as code 0: "
            + ", ".join(unseen_values)
        )
    wi_proba = xgb_model.predict_proba(input_row)[0][1]

    st.markdown("---")
    res_col1, res_col2 = st.columns([1, 2])
    with res_col1:
        st.metric("Churn Probability", f"{wi_proba:.1%}")
        # Risk bands: < 30% low, 30–60% medium, >= 60% high.
        if wi_proba < 0.3:
            st.success("Low risk — customer likely to stay")
        elif wi_proba < 0.6:
            st.warning("Medium risk — consider retention offer")
        else:
            st.error("High risk — immediate intervention recommended")
    with res_col2:
        st.plotly_chart(plot_gauge(wi_proba), use_container_width=True)

    st.markdown("---")
    st.markdown("**Top Feature Drivers for This Configuration:**")
    sv_wi = get_shap_single(explainer, input_row)
    # waterfall draws on pyplot's current figure. Close it after rendering so
    # figures do not pile up on every widget interaction.
    plt.subplots(figsize=(10, 5))
    shap.plots.waterfall(sv_wi[0], max_display=10, show=False)
    fig_wi = plt.gcf()
    st.pyplot(fig_wi)
    plt.close(fig_wi)
|