File size: 56,802 Bytes
f3a6f24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
import streamlit as st
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt

from utils.data_loader import (
    load_raw_data, get_train_test, get_onehot_train_test,
    get_encoded_data, CATEGORICAL_COLS, NUMERIC_COLS,
)
from utils.models import (
    train_tree_models, train_lr_model, train_nb_model, evaluate_model,
    get_shap_explainer, get_shap_single,
)
from utils.visualizations import plot_roc_curves, plot_confusion_matrix, plot_gauge

# Page chrome: title, 🤖 icon and wide layout.
st.set_page_config(page_title="Churn Models", page_icon="🤖", layout="wide")
st.title("Churn Prediction Models")
st.markdown("---")

# Two encodings of the data: label-encoded features (Naive Bayes, Random
# Forest, XGBoost) and one-hot features (Logistic Regression).
# NOTE(review): the one-hot call's labels are discarded and y_train/y_test from
# the first call are reused for every model, which assumes both helpers return
# the same row split in the same order — confirm in utils.data_loader.
X_train, X_test, y_train, y_test, encoders, feature_cols = get_train_test()
X_train_oh, X_test_oh, _, _, feature_cols_oh = get_onehot_train_test()

# Session-state slots for the trained models and their metrics. They are
# populated by the "Train All Models" button below so that training happens
# once per browser session instead of on every Streamlit rerun.
if "churn_models_trained" not in st.session_state:
    st.session_state.churn_models_trained = False
    st.session_state.all_models = None
    st.session_state.model_test_data = None
    st.session_state.metrics = None

# Explain up front why the models' feature counts differ: Logistic Regression
# uses the one-hot matrix, every other model uses the label-encoded matrix
# (see how the models are trained and evaluated below).
st.info(
    "**Encoding note:** Logistic Regression is trained on **One-Hot Encoded** features "
    f"({len(feature_cols_oh)} columns) while Naive Bayes, Random Forest and XGBoost use "
    f"**Label Encoding** ({len(feature_cols)} columns). "
    "Each model is evaluated on the same encoding it was trained on."
)

# Until the models exist, show only a "train" button and stop rendering the
# rest of the page. On click: train every model, evaluate each on its own test
# matrix, cache everything in session state and rerun so the page below renders.
if not st.session_state.churn_models_trained:
    st.warning("⚠️ Models not trained yet. Click the button below to train all models.")
    if st.button("🚀 Train All Models", type="primary", use_container_width=True):
        with st.spinner("Training models..."):
            # Logistic Regression gets the one-hot matrix; Naive Bayes and the
            # tree models get the label-encoded matrix.
            tree_models = train_tree_models(X_train, y_train)
            lr_model = train_lr_model(X_train_oh, y_train)
            nb_model = train_nb_model(X_train, y_train)

            all_models = {
                "Logistic Regression": lr_model,
                "Naive Bayes": nb_model,
            }
            all_models.update(tree_models)

            # Each model must be scored on the same encoding it was trained on.
            # The tree-model entries are derived from the names that
            # train_tree_models actually returned, so this map cannot drift out
            # of sync with `all_models` (a hard-coded name list would raise
            # KeyError in the evaluation below if the helper's keys changed).
            model_test_data = {
                "Logistic Regression": X_test_oh,
                "Naive Bayes": X_test,
            }
            model_test_data.update({tree_name: X_test for tree_name in tree_models})

            metrics = {
                model_name: evaluate_model(fitted, model_test_data[model_name], y_test)
                for model_name, fitted in all_models.items()
            }

            st.session_state.all_models = all_models
            st.session_state.model_test_data = model_test_data
            st.session_state.metrics = metrics
            st.session_state.churn_models_trained = True
            st.rerun()
    # Nothing below this point can render without trained models.
    st.stop()

# Models are trained - retrieve them (and their per-model test matrices and
# evaluation metrics) from session state for the tabs below.
all_models = st.session_state.all_models
model_test_data = st.session_state.model_test_data
metrics = st.session_state.metrics

st.success("✅ All models trained and ready!")

# One tab per view of the models; each `with tab_*:` section below fills one.
tab_how, tab_predict, tab_compare, tab_shap_global, tab_shap_individual, tab_whatif = st.tabs(
    ["How The Algorithms Work", "Predict and Compare", "Model Comparison", "Feature Importance (SHAP)", "Individual Explanations", "What-If Predictor"]
)

# ── Tab: How The Algorithms Work ─────────────────────────────────────────────
with tab_how:
    # Intro for the explainer tab that walks through all four algorithms.
    st.subheader("How Each Algorithm Works — A Visual Guide")
    st.markdown(
        "We use four very different algorithms to predict churn. Each approaches "
        "the problem in its own way. Understanding the differences helps explain "
        "why one model might outperform another and when predictions diverge."
    )

    # ── Logistic Regression ──────────────────────────────────────────────────
    st.markdown("---")
    st.markdown("### 1. Logistic Regression")
    st.markdown("*The simplest model — and often a strong baseline.*")

    # Pipeline diagram: features -> weighted sum -> sigmoid -> probability.
    st.graphviz_chart("""
        digraph lr {
            rankdir=LR
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.3,0.15"]
            edge [color="#888888", penwidth=1.5]

            features [label="Customer Features\\ntenure, charges,\\ncontract type, ...", fillcolor="#dbeafe", color="#3b82f6"]
            weighted [label="Weighted Sum\\nw₁×tenure + w₂×charges\\n+ w₃×contract + ...", fillcolor="#e0e7ff", color="#6366f1"]
            sigmoid  [label="Sigmoid Function\\nSquash to 0-1", fillcolor="#fce7f3", color="#ec4899"]
            output   [label="Probability\\n0.82 → Churn\\n0.15 → Retain", fillcolor="#d1fae5", color="#10b981"]

            features -> weighted -> sigmoid -> output
        }
    """)

    # Local import: plotly is first needed here. A `with` block does not create
    # a new scope, so `go` remains a module-level name for later code.
    import plotly.graph_objects as go

    # Illustrative S-curve sigma(z) = 1 / (1 + e^-z) over z in [-8, 8], with the
    # 0.5 decision threshold and the retain/churn halves shaded.
    x_sig = np.linspace(-8, 8, 200)
    y_sig = 1 / (1 + np.exp(-x_sig))
    fig_sig = go.Figure()
    fig_sig.add_trace(go.Scatter(x=x_sig, y=y_sig, mode="lines", line=dict(color="#6366f1", width=3), name="Sigmoid"))
    fig_sig.add_hline(y=0.5, line_dash="dash", line_color="gray", annotation_text="Decision boundary (0.5)")
    fig_sig.add_vrect(x0=-8, x1=0, fillcolor="#d1fae5", opacity=0.15, annotation_text="Retain", annotation_position="bottom left")
    fig_sig.add_vrect(x0=0, x1=8, fillcolor="#fee2e2", opacity=0.15, annotation_text="Churn", annotation_position="bottom right")
    fig_sig.update_layout(
        title="The Sigmoid Curve — Turns Any Number into a Probability",
        xaxis_title="Weighted Sum of Features (z = w₁x₁ + w₂x₂ + ...)",
        yaxis_title="Churn Probability",
        height=350,
        yaxis=dict(range=[0, 1]),
    )
    st.plotly_chart(fig_sig, use_container_width=True)

    st.markdown(
        """
        **How it works:**
        1. Each customer feature gets a **weight** (a number the model learns). For example,
           "month-to-month contract" might get a high positive weight (pushes toward churn),
           while "long tenure" gets a negative weight (pushes away from churn).
        2. The model adds up: `weight₁ × feature₁ + weight₂ × feature₂ + ...`
        3. This sum could be any number (-∞ to +∞). The **sigmoid function** (the S-curve above)
           squashes it into a probability between 0 and 1.
        4. If the probability > 0.5 → predict "Churn". Otherwise → "Retain".

        **Why we use One-Hot Encoding for this model:** Because the weighted sum treats numbers
        at face value. If we encoded "Month-to-month"=0, "One year"=1 and "Two year"=2, the model
        would think two-year is "more" of something than month-to-month (and that one-year sits
        exactly halfway between them) — which is nonsensical.
        One-Hot avoids this by giving each category its own yes/no column.

        **Strengths:** Fast, interpretable (weights directly tell you what matters), good baseline.
        **Weaknesses:** Assumes a linear relationship between features and the log-odds of churn.
        Can't capture complex interactions (e.g., "fiber optic is only risky for short-tenure customers").
        """
    )

    with st.expander("Step-by-step numerical walkthrough — Logistic Regression", expanded=False):
        # Worked example with illustrative numbers. The weights, scaled inputs,
        # Step 2 arithmetic and the contribution table all agree on z = 4.2.
        st.markdown(
            """
            Let's trace through exactly how the model makes one prediction.

            **A single customer:**

            | Feature | Value |
            |---|---|
            | tenure | 3 months |
            | MonthlyCharges | $85 |
            | Contract_Month-to-month | 1 (one-hot) |
            | Contract_One year | 0 |
            | Contract_Two year | 0 |
            | InternetService_Fiber optic | 1 |
            | Partner | 0 (no) |

            (After one-hot encoding and scaling)

            ---

            **Step 1 — The model has learned weights (one per feature)**

            | Feature | Learned Weight | Meaning |
            |---|---|---|
            | tenure (scaled) | -1.2 | Longer tenure → less churn |
            | MonthlyCharges (scaled) | +0.5 | Higher charges → more churn |
            | Contract_Month-to-month | +1.4 | Month-to-month → high churn risk |
            | Contract_Two year | -0.9 | Two-year contract → protective |
            | InternetService_Fiber optic | +0.8 | Fiber optic → correlated with churn |
            | Partner | -0.3 | Having a partner → slightly protective |
            | (bias term) | -0.2 | Baseline offset |

            ---

            **Step 2 — Compute the weighted sum (z)**
            """
        )
        st.latex(r"z = b + w_1 x_1 + w_2 x_2 + w_3 x_3 + \ldots")
        # Fenced block so the three arithmetic lines render on separate lines
        # instead of being merged into one paragraph.
        st.markdown(
            """
            For our customer (tenure scaled to -1.5 since 3 months is below average;
            MonthlyCharges scaled to 0.8 since it is above average):

            ```
            z = -0.2 + (-1.2 × -1.5) + (0.5 × 0.8) + (1.4 × 1) + (0.8 × 1) + (-0.3 × 0)
            z = -0.2 + 1.8 + 0.4 + 1.4 + 0.8 + 0
            z = 4.2
            ```

            ---

            **Step 3 — Apply the sigmoid function**
            """
        )
        st.latex(r"P(\text{churn}) = \frac{1}{1 + e^{-z}} = \frac{1}{1 + e^{-4.2}} = \frac{1}{1 + 0.015} = 0.985")
        st.markdown(
            """
            **Probability = 98.5%** → Predict **Churn**

            ---

            **Step 4 — How did the model learn these weights?**

            | Stage | What happens |
            |---|---|
            | Start | All weights = 0 (or random small numbers) |
            | Each step | Pick a customer, predict, calculate error |
            | Update | Adjust weights using gradient descent |
            | | `weight = weight - learning_rate × gradient` |
            | Repeat | Until weights converge (error stops decreasing) |

            The gradient tells the model: "This weight should go up" or "down" based on
            whether increasing it would reduce the prediction error.

            After training on all ~5,600 training customers for many iterations,
            the weights settle to values that best separate churners from non-churners.

            ---

            **Why this customer scores so high:**

            | Feature | Contribution to z |
            |---|---|
            | Short tenure (-1.5 scaled) | +1.8 (weight is negative, value is negative → positive contribution) |
            | Month-to-month contract | +1.4 |
            | Fiber optic | +0.8 |
            | High charges | +0.4 |
            | Bias term | -0.2 |
            | **Total push toward churn (z)** | **+4.2** → sigmoid → 98.5% |

            Every feature contributes additively. That's the key property (and limitation)
            of logistic regression — it can't model interactions like "fiber optic is
            only risky for short-tenure customers."
            """
        )

    # ── Naive Bayes ──────────────────────────────────────────────────────────
    st.markdown("---")
    st.markdown("### 2. Naive Bayes (Gaussian)")
    st.markdown("*A probabilistic model that assumes features are independent within each class.*")

    # Pipeline diagram: features -> per-class probability -> compare -> pick.
    st.graphviz_chart("""
        digraph nb {
            rankdir=LR
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.3,0.15"]
            edge [color="#888888", penwidth=1.5]

            features [label="Customer Features\\ntenure, charges,\\ncontract type, ...", fillcolor="#dbeafe", color="#3b82f6"]
            prob [label="Calculate Probability\\nfor each class\\nusing Bayes' Theorem", fillcolor="#e0e7ff", color="#6366f1"]
            compare [label="Compare\\nP(Churn | features)\\nvs\\nP(Retain | features)", fillcolor="#fce7f3", color="#ec4899"]
            output [label="Pick Higher\\nProbability", fillcolor="#d1fae5", color="#10b981"]

            features -> prob -> compare -> output
        }
    """)

    st.markdown(
        """
        **How it works:**

        Naive Bayes uses **Bayes' Theorem** to calculate the probability of churn given the customer's features.
        The "naive" part means it assumes all features are **conditionally independent given the class** — among
        churners (and, separately, among retainers), knowing tenure doesn't tell you anything about monthly charges
        (which isn't true in reality, but the assumption simplifies the math).

        **Formula:**
        """
    )

    st.latex(r"P(\text{Churn} \mid \text{features}) = \frac{P(\text{features} \mid \text{Churn}) \times P(\text{Churn})}{P(\text{features})}")

    # Worked two-feature example. The Gaussian values are densities; the
    # products in Steps 3-4 and the 58% posterior are consistent with them.
    st.markdown(
        """
        For continuous features (like tenure, charges), Gaussian Naive Bayes assumes each feature
        follows a **normal distribution** within each class.

        **Step-by-step example:**

        Suppose we want to predict if Customer X will churn. They have:
        - Tenure = 12 months
        - Monthly Charges = $70

        **Step 1:** Calculate P(Tenure=12 | Churn) and P(Tenure=12 | Retain)

        From training data, we know:
        - Churners have average tenure = 18 months, std = 15
        - Retainers have average tenure = 38 months, std = 20

        Using the Gaussian (bell curve) formula (strictly, these are probability *densities* rather
        than probabilities — that's fine, because only their relative sizes matter):
        - P(Tenure=12 | Churn) = 0.024 (12 is close to churner average)
        - P(Tenure=12 | Retain) = 0.008 (12 is far from retainer average)

        **Step 2:** Do the same for Monthly Charges

        - P(Charges=70 | Churn) = 0.015
        - P(Charges=70 | Retain) = 0.012

        **Step 3:** Multiply probabilities (the "naive" independence assumption)

        - P(features | Churn) = 0.024 × 0.015 = 0.00036
        - P(features | Retain) = 0.008 × 0.012 = 0.000096

        **Step 4:** Apply Bayes' Theorem

        - P(Churn | features) ∝ 0.00036 × P(Churn) = 0.00036 × 0.27 = 0.0000972
        - P(Retain | features) ∝ 0.000096 × P(Retain) = 0.000096 × 0.73 = 0.00007008

        Normalize: P(Churn | features) = 0.0000972 / (0.0000972 + 0.00007008) = **58%**

        **Prediction: Churn (58% probability)**

        **Why Naive Bayes is different:**
        - Makes strong independence assumption (features don't interact)
        - Very fast to train and predict
        - Works well when features are actually somewhat independent
        - Often gives different predictions than tree-based models because it doesn't
          capture feature interactions (e.g., "fiber optic + short tenure" combo)
        """
    )

    # ── Random Forest ────────────────────────────────────────────────────────
    st.markdown("---")
    st.markdown("### 3. Random Forest")
    st.markdown("*An ensemble of decision trees that vote together.*")

    # Ensemble diagram: training data fans out to many trees whose votes are
    # combined into one prediction.
    st.graphviz_chart("""
        digraph rf {
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
            edge [color="#888888", penwidth=1.5]

            data [label="Training Data", fillcolor="#dbeafe", color="#3b82f6"]

            subgraph cluster_trees {
                label="200 Decision Trees (each sees a random subset)"
                style=dashed
                color="#94a3b8"
                fontname="Helvetica"
                fontsize=10

                t1 [label="Tree 1\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
                t2 [label="Tree 2\\n→ Retain", fillcolor="#d1fae5", color="#10b981"]
                t3 [label="Tree 3\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
                dots [label="...\\n(197 more)", shape=plaintext]
                t200 [label="Tree 200\\n→ Churn", fillcolor="#d1fae5", color="#10b981"]
            }

            vote [label="Majority Vote\\n3 out of 4 say Churn\\n→ Final: Churn (75%)", fillcolor="#fef3c7", color="#f59e0b"]

            data -> t1
            data -> t2
            data -> t3
            data -> t200
            t1 -> vote
            t2 -> vote
            t3 -> vote
            t200 -> vote
        }
    """)

    # A single illustrative tree, to show what one member of the forest does.
    st.markdown("**How a single decision tree works:**")
    st.graphviz_chart("""
        digraph tree_example {
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.2,0.1"]
            edge [fontname="Helvetica", fontsize=9, color="#888888", penwidth=1.5]

            q1 [label="Contract = \\nMonth-to-month?", fillcolor="#e0e7ff", color="#6366f1"]
            q2 [label="tenure < 12\\nmonths?", fillcolor="#e0e7ff", color="#6366f1"]
            q3 [label="MonthlyCharges\\n> $70?", fillcolor="#e0e7ff", color="#6366f1"]
            churn1 [label="CHURN\\n(85% probability)", fillcolor="#fee2e2", color="#ef4444"]
            retain1 [label="RETAIN\\n(70% probability)", fillcolor="#d1fae5", color="#10b981"]
            churn2 [label="CHURN\\n(60% probability)", fillcolor="#fee2e2", color="#ef4444"]
            retain2 [label="RETAIN\\n(90% probability)", fillcolor="#d1fae5", color="#10b981"]

            q1 -> q2 [label=" Yes"]
            q1 -> q3 [label=" No"]
            q2 -> churn1 [label=" Yes"]
            q2 -> retain1 [label=" No"]
            q3 -> churn2 [label=" Yes"]
            q3 -> retain2 [label=" No"]
        }
    """)

    st.markdown(
        """
        **How it works:**
        1. Imagine asking a series of yes/no questions: "Is the contract month-to-month?"
           → "Is tenure less than 12 months?" → "Are charges above $70?"
           That's a **decision tree** — it keeps splitting until it reaches an answer.
        2. A single tree is easy to overfit (it memorizes the training data too closely).
           So Random Forest builds **200 trees**, each trained on a **random sample**
           of the data and a **random subset** of features.
        3. For a new customer, all 200 trees make their prediction and we take
           the **majority vote**.

        **Why "Random"?** Each tree only sees a random portion of the data and features.
        This diversity prevents the forest from over-relying on any single pattern.

        **Why Label Encoding is fine:** Trees split on thresholds (e.g., "is Contract < 1?").
        They never multiply the encoded number — so the integer values don't introduce false relationships.

        **Strengths:** Handles non-linear patterns, resistant to overfitting, works with Label Encoding.
        **Weaknesses:** Slower than Logistic Regression, less interpretable (200 trees are hard to inspect by hand).
        """
    )

    with st.expander("Step-by-step numerical walkthrough — Random Forest", expanded=False):
        # Worked example for the same customer as the Logistic Regression
        # walkthrough; all tree paths and vote counts are illustrative.
        st.markdown(
            """
            **Same customer:** tenure=3, MonthlyCharges=$85, Contract=Month-to-month,
            Internet=Fiber optic, Partner=No.

            ---

            **Step 1 — Each tree sees different data**

            Random Forest creates 200 trees. Each tree gets:
            - A **bootstrap sample** — as many random draws from the training customers as there
              are customers, *with replacement*, so some customers appear more than once and only
              about 63% of the distinct customers end up in any one tree's sample
            - A **random subset of features** at each split (e.g., 4 out of 19 features)

            | Tree | Training Customers | Features Available |
            |---|---|---|
            | Tree 1 | Customers #12, #45, #45, #78, #102, ... | tenure, Contract, Partner, TechSupport |
            | Tree 2 | Customers #3, #22, #56, #56, #89, ... | MonthlyCharges, InternetService, gender, SeniorCitizen |
            | Tree 3 | Customers #7, #33, #41, #67, #67, ... | tenure, PaymentMethod, MonthlyCharges, OnlineSecurity |
            | ... | ... | ... |

            Notice: Customer #45 appears twice in Tree 1 (bootstrap sampling), and each
            tree considers different features. (To keep the table readable it shows one feature
            subset per tree; in practice a fresh random subset is drawn at every split.)

            ---

            **Step 2 — Each tree grows by finding the best splits**

            **Tree 1** (has tenure, Contract, Partner, TechSupport):

            The algorithm tries every possible split and picks the one that best
            separates churners from non-churners (measured by **Gini impurity**):
            """
        )
        st.latex(r"\text{Gini} = 1 - p_{\text{churn}}^2 - p_{\text{retain}}^2")
        st.markdown(
            """
            | Candidate Split | Left Group (Churn%) | Right Group (Churn%) | Gini Improvement |
            |---|---|---|---|
            | Contract < 1 (Month-to-month) | 42% churn | 12% churn | **0.18** ← best |
            | tenure < 12 | 38% churn | 20% churn | 0.14 |
            | Partner = 0 | 30% churn | 24% churn | 0.03 |

            Contract split wins → becomes the first question.

            Then each branch splits again on the remaining features, creating deeper
            questions. This continues until leaves are pure or a depth limit is reached.

            ---

            **Step 3 — Each tree predicts independently**

            For our customer (Month-to-month, tenure=3, Fiber optic, no Partner):

            | Tree | Path through the tree | Prediction |
            |---|---|---|
            | Tree 1 | Contract=M-t-m → tenure<12 → **CHURN** | Churn (92%) |
            | Tree 2 | Charges>$70 → Fiber optic → **CHURN** | Churn (78%) |
            | Tree 3 | tenure<6 → Charges>$60 → **CHURN** | Churn (85%) |
            | Tree 4 | Contract=M-t-m → No TechSupport → **CHURN** | Churn (88%) |
            | Tree 5 | Fiber optic → tenure<24 → **RETAIN** | Retain (55%) |
            | ... | ... | ... |

            ---

            **Step 4 — Majority vote**

            Out of 200 trees:
            - 172 trees predict **Churn**
            - 28 trees predict **Retain**

            `Probability = 172/200 = 86%` → Predict **Churn**

            Note: Tree 5 predicted Retain — that's fine. The diversity is intentional.
            If every tree agreed perfectly, there would be no benefit from having 200 of them.

            ---

            **Why Random Forest handles non-linear patterns:**

            A single tree can capture "Contract=Month-to-month AND tenure<12 → high churn"
            without needing to encode this interaction explicitly.

            The forest combines 200 different views of the data, each capturing
            different interaction patterns. The majority vote smooths out individual
            tree errors — this is why a forest is much less prone to overfitting than a single tree.
            """
        )

    # ── XGBoost ──────────────────────────────────────────────────────────────
    st.markdown("---")
    st.markdown("### 4. XGBoost (Extreme Gradient Boosting)")
    st.markdown("*A frequent winner of Kaggle competitions on tabular data.*")

    # Sequential-boosting diagram: each tree is fitted to the errors left by
    # the trees before it, and the final prediction sums all of them.
    st.graphviz_chart("""
        digraph xgb {
            rankdir=TB
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
            edge [color="#888888", penwidth=1.5]

            data [label="Training Data", fillcolor="#dbeafe", color="#3b82f6"]

            t1 [label="Tree 1\\nLearns the main pattern\\n(e.g., contract type)", fillcolor="#d1fae5", color="#10b981"]
            e1 [label="Errors from Tree 1\\n(customers it got wrong)", fillcolor="#fee2e2", color="#ef4444"]
            t2 [label="Tree 2\\nFocuses on Tree 1's mistakes", fillcolor="#d1fae5", color="#10b981"]
            e2 [label="Remaining errors", fillcolor="#fee2e2", color="#ef4444"]
            t3 [label="Tree 3\\nFocuses on remaining mistakes", fillcolor="#d1fae5", color="#10b981"]
            dots [label="... (200 trees total, each fixing\\nthe previous tree's errors)", shape=plaintext, fontname="Helvetica"]
            final [label="Final Prediction\\nSum of all 200 trees\\n(each weighted by learning rate)", fillcolor="#fef3c7", color="#f59e0b"]

            data -> t1 -> e1 -> t2 -> e2 -> t3 -> dots -> final
        }
    """)

    st.markdown(
        """
        **How it works:**
        1. **Tree 1** tries to predict churn for all customers. It gets some right, some wrong.
        2. **Tree 2** doesn't start from scratch — it specifically focuses on the customers
           that Tree 1 got **wrong**. It tries to correct those mistakes.
        3. **Tree 3** focuses on the remaining mistakes from Tree 1 + Tree 2 combined.
        4. This continues for 200 rounds. Each new tree is a specialist in fixing
           what all previous trees couldn't get right.
        5. The final prediction is the **weighted sum** of all 200 trees.

        **The key difference from Random Forest:**
        - Random Forest: 200 trees trained **independently** (in parallel), then vote.
        - XGBoost: 200 trees trained **sequentially**, each one learning from the previous one's errors.

        **Why "Gradient Boosting"?** "Gradient" refers to the mathematical technique used to
        determine *how* each new tree should focus on errors. It's the same gradient descent
        concept used in deep learning.

        **Strengths:** Usually the most accurate model, handles complex non-linear patterns,
        built-in regularization prevents overfitting.
        **Weaknesses:** Slower to train, harder to interpret, more hyperparameters to tune.
        """
    )

    with st.expander("Step-by-step numerical walkthrough β€” XGBoost", expanded=False):
        # Steps 1-3 of the XGBoost walkthrough: base score in log-odds, first
        # tree's correction, and the learning-rate-scaled update. The numbers
        # are illustrative but internally consistent (-1.02 -> -0.975 -> 27.4%).
        st.markdown(
            """
            **Same customer:** tenure=3, MonthlyCharges=$85, Contract=Month-to-month,
            Internet=Fiber optic, Partner=No.

            ---

            **Step 1 — Start with a base prediction**

            Before any trees, XGBoost starts every customer from the same base prediction —
            in this example, the overall churn rate:

            `base prediction = overall churn rate = 0.265` (26.5%)

            Convert to log-odds:
            """
        )
        st.latex(r"\text{log-odds} = \ln\left(\frac{0.265}{1 - 0.265}\right) = \ln(0.361) = -1.02")
        st.markdown(
            """
            Every customer starts at -1.02 (26.5% churn probability).

            ---

            **Step 2 — Tree 1 learns from the errors**

            For each customer, calculate the **residual** (how wrong the base prediction is):

            | Customer | Actual | Current Prediction | Residual |
            |---|---|---|---|
            | #1 | Churned (1) | 0.265 | +0.735 (prediction too low) |
            | #2 | Retained (0) | 0.265 | -0.265 (prediction too high) |
            | #3 | Churned (1) | 0.265 | +0.735 |
            | Our customer | Churned (1) | 0.265 | +0.735 |

            Tree 1 tries to predict these residuals (not the original labels).

            Tree 1 output for our customer: **+0.45** (it learned a partial correction).

            ---

            **Step 3 — Update prediction with learning rate**

            XGBoost scales each tree's output by a **learning rate** (η; 0.1 in this example —
            XGBoost's own default is 0.3) to prevent overshooting:

            `new log-odds = -1.02 + (0.1 × 0.45) = -1.02 + 0.045 = -0.975`
            """
        )
        st.latex(r"P = \frac{1}{1 + e^{-(-0.975)}} = \frac{1}{1 + 2.65} = 0.274")
        st.markdown(
            """
            Probability moved from 26.5% to 27.4%. A small step in the right direction.

            ---

            **Step 4 β€” Tree 2 focuses on remaining errors**

            New residuals (using updated predictions):

            | Customer | Actual | Updated Prediction | New Residual |
            |---|---|---|---|
            | Our customer | 1 | 0.274 | +0.726 (still too low) |

            Tree 2 learns these new residuals. Output for our customer: **+0.52**

            `log-odds = -0.975 + (0.1 Γ— 0.52) = -0.975 + 0.052 = -0.923`
            `P = 28.4%`

            ---

            **Step 5 β€” Continue for 200 trees**

            | After Tree | Log-odds | Churn Probability |
            |---|---|---|
            | Base (no trees) | -1.02 | 26.5% |
            | Tree 1 | -0.975 | 27.4% |
            | Tree 2 | -0.923 | 28.4% |
            | Tree 10 | -0.42 | 39.7% |
            | Tree 50 | 0.85 | 70.1% |
            | Tree 100 | 1.72 | 84.8% |
            | Tree 200 | 2.95 | 95.0% |

            Each tree adds a small correction. After 200 trees, the probability
            climbed from 26.5% to 95.0%.

            ---

            **The final prediction formula:**
            """
        )
        st.latex(r"\text{prediction} = \text{base} + \eta \cdot f_1(x) + \eta \cdot f_2(x) + \ldots + \eta \cdot f_{200}(x)")
        st.markdown(
            """
            Where each f(x) is a small decision tree focused on the remaining errors.

            ---

            **Key differences from Random Forest:**

            | | Random Forest | XGBoost |
            |---|---|---|
            | **Trees learn** | Independently (in parallel) | Sequentially (each from previous errors) |
            | **Each tree predicts** | The original label (Churn/Retain) | The **residual error** from all previous trees |
            | **Learning rate** | Not applicable | Controls step size (Ξ·=0.1 is typical) |
            | **Why it helps** | Diversity from random sampling | Each tree is a specialist in remaining errors |
            | **Risk** | Hard to overfit | Can overfit if too many trees or learning rate too high |

            The learning rate is crucial: without it (Ξ·=1.0), early trees would over-correct
            and the model would overfit. Small steps (Ξ·=0.1) force the model to build
            gradually, which produces better generalization.
            """
        )

    # ── SHAP ─────────────────────────────────────────────────────────────────
    # Intuition for SHAP: a diagram of per-feature pushes from the base rate to the
    # final prediction (labels are illustrative and add up: 26.5+18+12+5+3-4 = 60.5),
    # followed by a plain-language explanation.
    st.markdown("---")
    st.markdown("### 4. SHAP — How We Explain Predictions")
    st.markdown("*Making the black box transparent.*")

    st.graphviz_chart("""
        digraph shap {
            rankdir=LR
            node [shape=box, style="rounded,filled", fontname="Helvetica", fontsize=10, margin="0.25,0.12"]
            edge [color="#888888", penwidth=1.5]

            base [label="Base Rate\\n26.5% churn\\n(average customer)", fillcolor="#e0e7ff", color="#6366f1"]
            f1 [label="Contract =\\nMonth-to-month\\n+18%", fillcolor="#fee2e2", color="#ef4444"]
            f2 [label="tenure =\\n2 months\\n+12%", fillcolor="#fee2e2", color="#ef4444"]
            f3 [label="TechSupport =\\nNo\\n+5%", fillcolor="#fee2e2", color="#ef4444"]
            f4 [label="TotalCharges =\\n$150 (low)\\n+3%", fillcolor="#fee2e2", color="#ef4444"]
            f5 [label="Partner = Yes\\n-4%", fillcolor="#d1fae5", color="#10b981"]
            pred [label="Final Prediction\\n60.5% churn", fillcolor="#fef3c7", color="#f59e0b"]

            base -> f1 -> f2 -> f3 -> f4 -> f5 -> pred
        }
    """)

    st.markdown(
        """
        **How SHAP works:**

        Every prediction starts from the **base rate** — the overall churn rate in the data (~26.5%).
        Then SHAP shows how each feature **pushes** that prediction up or down:

        - Month-to-month contract → pushes probability **up** (toward churn)
        - Low tenure → pushes probability **up** (new customer, high risk)
        - Having a partner → pushes probability **down** (slightly protective)
        - Each feature gets a + or - contribution, and they all add up to the final prediction.

        **Why this matters for business:** If the model predicts an 80% churn probability,
        SHAP tells you *why* — "It's mainly because they're on a month-to-month contract
        and only been with us for 2 months." That's actionable: offer them a yearly contract
        with a discount.

        **The name "SHAP"** stands for SHapley Additive exPlanations, based on Shapley values
        from game theory — a mathematically rigorous way to fairly distribute credit among features.
        """
    )

    # Worked Shapley-value example for the running example customer (XGBoost predicts
    # 95% churn). The numbers are illustrative and chosen so that base value plus all
    # contributions equals the prediction.
    with st.expander("Step-by-step numerical walkthrough — SHAP values", expanded=False):
        st.markdown(
            """
            SHAP uses a concept from game theory: **Shapley values**. The idea is to
            fairly distribute credit among "players" (features) for the "payout" (prediction).

            **The question:** XGBoost predicts 95% churn for our customer. But *which features
            caused this?* And by how much?

            ---

            **Step 1 — Start from the base value**

            The base value is the average model output across all training customers:

            `base value = 26.5%` (overall churn rate)

            ---

            **Step 2 — Measure each feature's marginal contribution**

            For each feature, SHAP asks: "If I remove this feature, how much does the
            prediction change?" But it's more sophisticated — it checks this across
            **every possible combination** of other features.

            **Example for "Contract = Month-to-month":**

            | Features included (besides Contract) | Prediction with Contract | Prediction without | Contribution of Contract |
            |---|---|---|---|
            | None | 45% | 26.5% | +18.5% |
            | tenure only | 52% | 35% | +17% |
            | tenure + MonthlyCharges | 68% | 48% | +20% |
            | tenure + MonthlyCharges + Internet | 82% | 60% | +22% |
            | All features | 95% | 73% | +22% |

            The SHAP value for Contract is the **average of these marginal contributions**
            = (18.5 + 17 + 20 + 22 + 22) / 5 = **+19.9%**

            *(Simplified: this table shows only five combinations and takes a plain average.
            The exact Shapley formula averages over every possible combination, weighting each
            one by how many feature orderings lead to it.)*

            ---

            **Step 3 — Do this for every feature**

            | Feature | SHAP Value | Direction |
            |---|---|---|
            | Contract = Month-to-month | +19.9% | Pushes toward churn |
            | tenure = 3 months | +15.2% | Pushes toward churn |
            | InternetService = Fiber optic | +11.8% | Pushes toward churn |
            | MonthlyCharges = $85 | +8.4% | Pushes toward churn |
            | TechSupport = No | +6.1% | Pushes toward churn |
            | OnlineSecurity = No | +4.2% | Pushes toward churn |
            | Partner = No | +2.1% | Pushes toward churn |
            | PaymentMethod = Electronic check | +1.3% | Pushes toward churn |
            | Other features | -0.5% | Small protective effects |

            ---

            **Step 4 — Everything adds up to the prediction**

            `26.5% (base) + 19.9% + 15.2% + 11.8% + 8.4% + 6.1% + 4.2% + 2.1% + 1.3% - 0.5% = 95.0%`

            This is guaranteed by the Shapley value properties — the contributions
            always sum exactly to the difference between the base value and the prediction.

            ---

            **Why Shapley values are "fair":**

            Shapley values are the *only* method that satisfies all four fairness properties:
            1. **Efficiency** — the base value plus all contributions equals the prediction
            2. **Symmetry** — features with identical effects get identical values
            3. **Null player** — irrelevant features get zero contribution
            4. **Linearity** — explaining a sum of models gives the sum of their explanations

            This is why SHAP is preferred over simpler methods like "feature importance" —
            it's mathematically guaranteed to distribute credit fairly.
            """
        )

    # ── Comparison Table ─────────────────────────────────────────────────────
    # One-glance summary of the linear model and the two tree ensembles; the
    # markdown table is kept in a named string and rendered after the heading.
    st.markdown("---")
    st.markdown("### Quick Comparison")
    quick_comparison_md = """
        | | Logistic Regression | Random Forest | XGBoost |
        |---|---|---|---|
        | **How it learns** | Finds the best weights for a linear equation | Builds many independent trees and averages their votes | Builds trees sequentially, each correcting the last |
        | **Encoding** | One-Hot (needs binary columns) | Label (integers are fine) | Label (integers are fine) |
        | **Speed** | Very fast | Moderate | Slower |
        | **Accuracy** | Good baseline | Very good | Usually best |
        | **Interpretability** | High (weights = feature importance) | Medium (feature importance available) | Medium (needs SHAP for full explanation) |
        | **Best for** | Simple, linear relationships | Non-linear patterns with moderate data | Complex patterns, competitions, production systems |
        """
    st.markdown(quick_comparison_md)

    # ── Why Each Model ────────────────────────────────────────────────────────
    # Rationale, industry usage and trade-offs for every model on this page, how
    # they complement each other, and why the linear baseline can compete here.
    st.markdown("---")
    st.markdown("### Why We Chose Each Model")

    st.markdown(
        """
        #### Logistic Regression — The Baseline

        | | |
        |---|---|
        | **Why we included it** | Every ML project needs a simple baseline to compare against. If a complex model can't beat Logistic Regression, the complexity isn't justified. |
        | **Where it's used in industry** | Credit scoring (banks are required to use interpretable models), medical diagnosis screening, spam detection, ad click prediction at scale. |
        | **Key benefit** | Full transparency — every feature gets a weight you can show to stakeholders. "Month-to-month contract increases the odds of churn by X%." |
        | **Limitation** | Assumes each feature contributes independently. Can't learn "fiber optic is only risky when tenure is short" without manually creating that interaction feature. |
        | **When to pick this** | When you need to explain *why* to regulators or non-technical stakeholders, or when you need a fast model that runs on millions of rows per second. |

        ---

        #### Naive Bayes — The Probabilistic Baseline

        | | |
        |---|---|
        | **Why we included it** | Provides a fundamentally different approach based on probability theory. Makes the independence assumption explicit, which helps identify when features actually interact. |
        | **Where it's used in industry** | Email spam filtering (Gmail), document classification, real-time sentiment analysis, medical diagnosis screening. |
        | **Key benefit** | Extremely fast training and prediction. Works well with small datasets. Outputs probabilities directly, although they tend to be poorly calibrated (pushed toward 0 or 1). |
        | **Limitation** | The independence assumption is often violated (e.g., contract type and monthly charges are correlated). Can't capture feature interactions. |
        | **When to pick this** | When speed matters more than accuracy, or as a baseline to test if feature interactions are important (if NB matches complex models, interactions probably add little). |

        ---

        #### Random Forest — The Reliable Workhorse

        | | |
        |---|---|
        | **Why we included it** | It captures non-linear patterns and feature interactions that Logistic Regression and Naive Bayes miss, without needing manual feature engineering. |
        | **Where it's used in industry** | Fraud detection (PayPal), customer segmentation, manufacturing defect prediction, insurance risk assessment. |
        | **Key benefit** | Robust to outliers and noise. Rarely overfits. Handles mixed feature types without scaling or encoding concerns. |
        | **Limitation** | Slower inference than Logistic Regression or Naive Bayes. The 200-tree ensemble is harder to explain to non-technical audiences. |
        | **When to pick this** | When you need reliable accuracy with minimal tuning. It's the "safe choice" — almost always performs well without surprises. |

        ---

        #### XGBoost — The Performance Leader

        | | |
        |---|---|
        | **Why we included it** | It's consistently the top-performing algorithm for structured/tabular data like our customer dataset. |
        | **Where it's used in industry** | A staple of winning Kaggle solutions on tabular data (2015–present). Used at Airbnb (search ranking), Uber (ETA prediction), major banks (credit risk). |
        | **Key benefit** | Sequential error correction means it learns from its own mistakes. Built-in regularization helps prevent overfitting. |
        | **Limitation** | More hyperparameters to tune. Slower to train than Random Forest. Requires SHAP for interpretability. |
        | **When to pick this** | When predictive accuracy is the priority and you have time to tune. For production churn models, XGBoost is typically the default choice. |

        ---

        #### SHAP — The Explainability Layer

        | | |
        |---|---|
        | **Why we included it** | A prediction without an explanation has no business value. SHAP makes any model interpretable. |
        | **Where it's used in industry** | Widely used to support explainability obligations for automated decisions (e.g. under the EU GDPR) and to explain credit decisions in banking. Standard at tech companies for model debugging. |
        | **Key benefit** | Works with *any* model. Gives both global importance (which features matter overall) and local explanations (why *this* customer was flagged). |
        | **When to use** | Always. SHAP should be part of every production ML system. |

        ---

        #### How They Complement Each Other

        | Role | Model | Why |
        |---|---|---|
        | **Linear baseline** | Logistic Regression | Sets the floor — any model must beat this to justify its complexity |
        | **Probabilistic baseline** | Naive Bayes | Tests if feature independence assumption holds. If NB performs poorly, we know interactions matter. |
        | **Production model** | XGBoost | Highest accuracy for deployment |
        | **Backup / ensemble** | Random Forest | If XGBoost generalizes poorly to new data, RF is a stable alternative |
        | **Explainability** | SHAP on XGBoost | Turns the best model into something stakeholders can act on |
        | **Live updates** | SGDClassifier (Live page) | The only model supporting `partial_fit` for incremental learning |

        In a real deployment, you would use XGBoost as the primary model with SHAP
        for explanations, keep Random Forest as a monitoring baseline, compare against
        Naive Bayes to validate that feature interactions are being captured, and use the
        SGDClassifier pattern for real-time drift adaptation.

        ---

        #### Why Logistic Regression Can Match or Beat Complex Models on This Dataset

        If Logistic Regression performs as well as XGBoost here, that's an important
        finding — not a flaw.

        | Factor | Why It Helps LR |
        |---|---|
        | **Small dataset (~5,600 training rows)** | XGBoost has far more parameters (200 trees × many splits) than LR (a few dozen weights, one per one-hot column). With limited data, complex models risk overfitting. LR's simplicity is an advantage. |
        | **Linear churn signal** | Contract type, tenure, and charges have straightforward relationships with churn. These are essentially linear effects — exactly what LR is designed for. |
        | **One-hot encoding** | LR receives one-hot encoded features (each category as its own binary column). This avoids false ordinal relationships that label encoding introduces for linear models. |
        | **Strong regularization** | LR's built-in L2 regularization prevents any single feature from dominating, keeping the model stable. |

        **What this means in practice:** A simpler, fully interpretable model achieving
        top accuracy is a **win**. Every prediction can be explained to the retention team,
        the model is faster to deploy, and regulatory compliance is easier.

        **When would XGBoost pull ahead?** With a much larger dataset (100K+ customers)
        and richer features (browsing behavior, support call logs, time-series usage patterns),
        XGBoost would exploit non-linear interactions that LR can't capture — like
        "fiber optic is only risky for short-tenure customers without tech support."
        """
    )

# ── Tab 0: Predict and Compare ───────────────────────────────────────────────
with tab_predict:
    # Live demo: score a few held-out customers with every trained model and show,
    # per customer, whether each model's prediction matches the actual outcome.
    st.subheader("Predict and Compare — Live Demo")

    # Up to 3 churned + 2 retained customers from the test split so both outcomes
    # are represented (fewer if the split has fewer of either class).
    churned_idx = y_test[y_test == 1].index[:3]
    retained_idx = y_test[y_test == 0].index[:2]
    demo_idx = churned_idx.tolist() + retained_idx.tolist()

    st.markdown(
        f"Below are **{len(demo_idx)} real customers** from the test set (data the models have never seen during training). "
        "Click **Run Predictions** to see what each model thinks — then compare against what actually happened."
    )

    raw_df_pred = load_raw_data()

    display_columns = [
        "customerID", "gender", "SeniorCitizen", "tenure", "Contract",
        "InternetService", "MonthlyCharges", "TotalCharges",
    ]
    demo_display = raw_df_pred.loc[demo_idx, display_columns].copy()
    # 1-based row labels so the profile table lines up with the "#" column below.
    demo_display.index = range(1, len(demo_display) + 1)
    demo_display.index.name = "#"

    st.markdown("#### Customer Profiles")
    st.dataframe(demo_display, use_container_width=True)

    if st.button("Run Predictions", type="primary", use_container_width=True):
        st.markdown("---")
        st.markdown("#### Prediction Results")

        model_names = list(all_models.keys())
        # One layout shared by the header and every result row so columns line up:
        # "#", customer ID, one column per model, actual outcome, overall status.
        col_spec = [0.5, 1.5] + [2] * len(model_names) + [1.2, 0.5]
        actual_pos = len(model_names) + 2
        status_pos = len(model_names) + 3
        status_icons = {"all_correct": "✅", "all_wrong": "❌", "mixed": "⚠️"}

        # Display column headers
        header_cols = st.columns(col_spec)
        header_cols[0].markdown("**#**")
        header_cols[1].markdown("**Customer ID**")
        for j, mn in enumerate(model_names):
            header_cols[j + 2].markdown(f"**{mn}**")
        header_cols[actual_pos].markdown("**Actual**")
        header_cols[status_pos].markdown("**✓**")

        st.markdown("---")

        results_rows = []
        for pos, idx in enumerate(demo_idx, 1):
            actual_val = y_test.loc[idx]
            actual_label = "Churned" if actual_val == 1 else "Retained"

            row_result = {
                "#": pos,
                "Customer ID": raw_df_pred.loc[idx, "customerID"],
            }

            all_correct = True
            all_wrong = True
            for model_name in model_names:
                model_obj = all_models[model_name]
                # Each model is scored on the test matrix in its own encoding
                # (one-hot vs label, per the Model Comparison tab).
                row_enc = model_test_data[model_name].loc[[idx]]
                pred = int(model_obj.predict(row_enc)[0])
                proba = float(model_obj.predict_proba(row_enc)[0][1])
                pred_label = "Churned" if pred == 1 else "Retained"
                correct = bool(pred == actual_val)
                if correct:
                    all_wrong = False
                else:
                    all_correct = False
                row_result[model_name] = f"{pred_label} ({proba:.0%})"
                row_result[f"{model_name}_correct"] = correct

            row_result["Actual"] = actual_label

            if all_correct:
                row_result["_status"] = "all_correct"
            elif all_wrong:
                row_result["_status"] = "all_wrong"
            else:
                row_result["_status"] = "mixed"

            results_rows.append(row_result)

        results_df = pd.DataFrame(results_rows)

        for _, row in results_df.iterrows():
            cols = st.columns(col_spec)
            cols[0].markdown(f"**{row['#']}**")
            cols[1].markdown(f"`{row['Customer ID']}`")
            for j, mn in enumerate(model_names):
                correct = row[f"{mn}_correct"]
                mark = "✓" if correct else "✗"
                cols[j + 2].markdown(
                    f"{'🟢' if correct else '🔴'} {row[mn]} {mark}"
                )
            cols[actual_pos].markdown(f"**{row['Actual']}**")
            cols[status_pos].markdown(status_icons[row["_status"]])

        st.markdown("---")

        n_total = len(results_rows)
        n_all_correct = sum(1 for r in results_rows if r["_status"] == "all_correct")
        n_all_wrong = sum(1 for r in results_rows if r["_status"] == "all_wrong")
        n_mixed = n_total - n_all_correct - n_all_wrong

        rc1, rc2, rc3 = st.columns(3)
        rc1.metric("All Models Correct", f"{n_all_correct} / {n_total}", delta="✅")
        rc2.metric("Mixed Results", f"{n_mixed} / {n_total}", delta="⚠️" if n_mixed > 0 else None)
        rc3.metric("All Models Wrong", f"{n_all_wrong} / {n_total}", delta="❌" if n_all_wrong > 0 else None, delta_color="inverse")

        st.markdown(
            "**Legend:** 🟢 = correct prediction, 🔴 = wrong prediction. "
            "Percentage shown is the model's estimated churn probability."
        )

# ── Tab 1: Model Comparison ──────────────────────────────────────────────────
with tab_compare:
    # Side-by-side evaluation of every trained model on the held-out test split.
    st.subheader("Performance Metrics")
    metrics_df = pd.DataFrame(metrics).T
    metrics_df = metrics_df.round(3)

    best_model = metrics_df["AUC"].idxmax()
    st.info(f"Best model by AUC: **{best_model}** ({metrics_df.loc[best_model, 'AUC']:.3f})")

    # Feature encoding each model was trained on. Reindexed to the models that
    # actually have metrics: a model missing from this mapping gets a "—"
    # placeholder instead of an empty cell, and a mapped model without metrics
    # does not add an all-empty row.
    # NOTE(review): the page also discusses Naive Bayes, which has no entry here —
    # add its encoding once confirmed.
    encoding_col = pd.Series({
        "Logistic Regression": "One-Hot",
        "Random Forest": "Label",
        "XGBoost": "Label",
    }, name="Encoding").reindex(metrics_df.index).fillna("—")
    display_metrics = pd.concat([encoding_col, metrics_df], axis=1)

    st.dataframe(
        display_metrics.style.highlight_max(axis=0, subset=metrics_df.columns, color="#c6efce"),
        use_container_width=True,
    )

    st.markdown("---")
    st.subheader("ROC Curves")
    roc_entries = [(name, all_models[name], model_test_data[name]) for name in all_models]
    st.plotly_chart(plot_roc_curves(roc_entries, y_test), use_container_width=True)

    st.markdown("---")
    st.subheader("Confusion Matrices")
    # One column per model (instead of a fixed 4) so adding or removing a model
    # neither raises IndexError nor leaves empty columns.
    cm_cols = st.columns(max(len(all_models), 1))
    for cm_col, name in zip(cm_cols, all_models):
        with cm_col:
            y_pred = all_models[name].predict(model_test_data[name])
            st.plotly_chart(
                plot_confusion_matrix(y_test, y_pred, title=name),
                use_container_width=True,
            )

# ── Tab 2: Global SHAP ──────────────────────────────────────────────────────
with tab_shap_global:
    st.subheader("Global Feature Importance — XGBoost")
    st.markdown("SHAP values show how much each feature pushes the prediction toward or away from churn.")

    # `xgb_model` and `explainer` are module-level names reused by the
    # individual-explanation and what-if tabs below.
    xgb_model = all_models["XGBoost"]
    explainer, shap_values = get_shap_explainer(xgb_model, X_train)

    fig_bar, ax_bar = plt.subplots(figsize=(10, 6))
    shap.plots.bar(shap_values, max_display=15, show=False, ax=ax_bar)
    st.pyplot(fig_bar)
    # Streamlit re-runs this script on every interaction; close rendered figures
    # so pyplot does not keep accumulating open figures across reruns.
    plt.close(fig_bar)

    st.markdown("---")
    st.markdown("**Beeswarm Plot** — Each dot is a customer. Color = feature value (red = high, blue = low).")

    # The beeswarm call is not given an axis; render whatever figure is current
    # afterwards, then close both it and the one created here (closing an
    # already-closed figure is a no-op).
    fig_bee, ax_bee = plt.subplots(figsize=(10, 8))
    shap.plots.beeswarm(shap_values, max_display=15, show=False)
    drawn_bee = plt.gcf()
    st.pyplot(drawn_bee)
    plt.close(drawn_bee)
    plt.close(fig_bee)

# ── Tab 3: Individual Explanations ───────────────────────────────────────────
with tab_shap_individual:
    # Explain the XGBoost prediction for one customer chosen from the raw data.
    st.subheader("Explain a Single Customer's Prediction")

    raw_df = load_raw_data()
    customer_ids = raw_df["customerID"].values
    # Only the first 200 IDs are offered to keep the select box manageable.
    selected_id = st.selectbox("Select Customer ID", customer_ids[:200])
    idx_in_raw = raw_df[raw_df["customerID"] == selected_id].index[0]

    enc_df, _ = get_encoded_data()

    # Take the encoded row (and true label) from whichever split holds this
    # customer, falling back to the full encoded frame.
    if idx_in_raw in X_test.index:
        row = X_test.loc[[idx_in_raw]]
        actual = y_test.loc[idx_in_raw]
    elif idx_in_raw in X_train.index:
        row = X_train.loc[[idx_in_raw]]
        actual = y_train.loc[idx_in_raw]
    else:
        row = enc_df.loc[[idx_in_raw], feature_cols]
        actual = enc_df.loc[idx_in_raw, "Churn"]

    proba = xgb_model.predict_proba(row)[0][1]

    c1, c2 = st.columns(2)
    c1.metric("Predicted Churn Probability", f"{proba:.1%}")
    c2.metric("Actual Outcome", "Churned" if actual == 1 else "Retained")

    st.markdown("**Customer Details (raw values):**")
    st.dataframe(raw_df[raw_df["customerID"] == selected_id].T, use_container_width=True)

    st.markdown("---")
    st.markdown("**SHAP Waterfall — Why this prediction?**")
    sv = get_shap_single(explainer, row)
    fig_wf, ax_wf = plt.subplots(figsize=(10, 6))
    shap.plots.waterfall(sv[0], max_display=12, show=False)
    # Render the current figure, then close it and the one created above so
    # repeated reruns do not leak matplotlib figures.
    drawn_wf = plt.gcf()
    st.pyplot(drawn_wf)
    plt.close(drawn_wf)
    plt.close(fig_wf)

# ── Tab 4: What-If Predictor ─────────────────────────────────────────────────
with tab_whatif:
    # Build a hypothetical customer from widgets, encode it like the training
    # data, and show the XGBoost churn probability with a SHAP waterfall.
    st.subheader("What-If Predictor")
    st.markdown("Adjust customer features and see how churn probability changes in real time.")

    wi_col1, wi_col2, wi_col3 = st.columns(3)

    with wi_col1:
        wi_gender = st.selectbox("Gender", ["Female", "Male"], key="wi_gender")
        wi_senior = st.selectbox("Senior Citizen", [0, 1], key="wi_senior")
        wi_partner = st.selectbox("Partner", ["Yes", "No"], key="wi_partner")
        wi_dependents = st.selectbox("Dependents", ["Yes", "No"], key="wi_dep")
        wi_tenure = st.slider("Tenure (months)", 0, 72, 12, key="wi_tenure")

    with wi_col2:
        wi_phone = st.selectbox("Phone Service", ["Yes", "No"], key="wi_phone")
        wi_multi = st.selectbox("Multiple Lines", ["No", "Yes", "No phone service"], key="wi_multi")
        wi_internet = st.selectbox("Internet Service", ["DSL", "Fiber optic", "No"], key="wi_inet")
        wi_security = st.selectbox("Online Security", ["Yes", "No", "No internet service"], key="wi_sec")
        wi_backup = st.selectbox("Online Backup", ["Yes", "No", "No internet service"], key="wi_bak")

    with wi_col3:
        wi_protection = st.selectbox("Device Protection", ["Yes", "No", "No internet service"], key="wi_prot")
        wi_support = st.selectbox("Tech Support", ["Yes", "No", "No internet service"], key="wi_sup")
        wi_tv = st.selectbox("Streaming TV", ["Yes", "No", "No internet service"], key="wi_tv")
        wi_movies = st.selectbox("Streaming Movies", ["Yes", "No", "No internet service"], key="wi_mov")
        wi_contract = st.selectbox("Contract", ["Month-to-month", "One year", "Two year"], key="wi_con")

    wi_col4, wi_col5 = st.columns(2)
    with wi_col4:
        wi_paperless = st.selectbox("Paperless Billing", ["Yes", "No"], key="wi_paper")
        wi_payment = st.selectbox(
            "Payment Method",
            ["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"],
            key="wi_pay",
        )
    with wi_col5:
        wi_monthly = st.slider("Monthly Charges ($)", 18.0, 120.0, 70.0, step=0.5, key="wi_monthly")
        wi_total = st.slider("Total Charges ($)", 18.0, 9000.0, 1500.0, step=10.0, key="wi_total")

    # The "No phone service" / "No internet service" categories encode the absence
    # of the base service. Align the dependent fields with the selected base service
    # so the model is never asked to score an impossible customer (e.g. no internet
    # but Online Security = "Yes").
    # NOTE(review): assumes the training data only contains these consistent
    # combinations — confirm against the dataset.
    has_phone = wi_phone == "Yes"
    has_internet = wi_internet != "No"
    adjusted_fields = []

    multi_value = wi_multi
    if not has_phone and multi_value != "No phone service":
        multi_value = "No phone service"
        adjusted_fields.append("Multiple Lines")
    elif has_phone and multi_value == "No phone service":
        multi_value = "No"
        adjusted_fields.append("Multiple Lines")

    internet_addons = {
        "OnlineSecurity": ("Online Security", wi_security),
        "OnlineBackup": ("Online Backup", wi_backup),
        "DeviceProtection": ("Device Protection", wi_protection),
        "TechSupport": ("Tech Support", wi_support),
        "StreamingTV": ("Streaming TV", wi_tv),
        "StreamingMovies": ("Streaming Movies", wi_movies),
    }
    addon_values = {}
    for addon_col, (addon_label, addon_val) in internet_addons.items():
        if not has_internet and addon_val != "No internet service":
            addon_val = "No internet service"
            adjusted_fields.append(addon_label)
        elif has_internet and addon_val == "No internet service":
            addon_val = "No"
            adjusted_fields.append(addon_label)
        addon_values[addon_col] = addon_val

    if adjusted_fields:
        st.caption(
            "Adjusted to match the selected phone/internet service: "
            + ", ".join(adjusted_fields)
        )

    input_dict = {
        "gender": wi_gender, "Partner": wi_partner, "Dependents": wi_dependents,
        "PhoneService": wi_phone, "MultipleLines": multi_value, "InternetService": wi_internet,
        **addon_values,
        "Contract": wi_contract,
        "PaperlessBilling": wi_paperless, "PaymentMethod": wi_payment,
    }
    numeric_dict = {
        "SeniorCitizen": wi_senior, "tenure": wi_tenure,
        "MonthlyCharges": wi_monthly, "TotalCharges": wi_total,
    }

    # Encode categoricals with the same per-column encoders used for training.
    _, enc_map = get_encoded_data()
    encoded_input = {}
    for col, val in input_dict.items():
        le = enc_map[col]
        if val in le.classes_:
            encoded_input[col] = le.transform([val])[0]
        else:
            # NOTE(review): an unseen category silently falls back to code 0.
            encoded_input[col] = 0

    encoded_input.update(numeric_dict)
    # Column order must match the order the model was trained with.
    input_row = pd.DataFrame([encoded_input])[feature_cols]

    wi_proba = xgb_model.predict_proba(input_row)[0][1]

    st.markdown("---")
    res_col1, res_col2 = st.columns([1, 2])
    with res_col1:
        st.metric("Churn Probability", f"{wi_proba:.1%}")
        if wi_proba < 0.3:
            st.success("Low risk — customer likely to stay")
        elif wi_proba < 0.6:
            st.warning("Medium risk — consider retention offer")
        else:
            st.error("High risk — immediate intervention recommended")
    with res_col2:
        st.plotly_chart(plot_gauge(wi_proba), use_container_width=True)

    st.markdown("---")
    st.markdown("**Top Feature Drivers for This Configuration:**")
    sv_wi = get_shap_single(explainer, input_row)
    fig_wi, ax_wi = plt.subplots(figsize=(10, 5))
    shap.plots.waterfall(sv_wi[0], max_display=10, show=False)
    # Render, then close the figures so Streamlit reruns (every slider move)
    # do not accumulate open matplotlib figures.
    drawn_wi = plt.gcf()
    st.pyplot(drawn_wi)
    plt.close(drawn_wi)
    plt.close(fig_wi)