File size: 3,175 Bytes
8ea1e26
 
4da4fcb
8ea1e26
4da4fcb
8ea1e26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
caa7bf6
a17a714
 
caa7bf6
 
8ea1e26
 
 
 
 
 
 
 
 
 
 
 
caa7bf6
8ea1e26
 
 
 
4da4fcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import pandas as pd
import streamlit as st
import shap
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_score, recall_score, brier_score_loss, log_loss
import streamlit.components.v1 as components

def compute_metrics(y_true, y_pred_proba, threshold=0.5):
    """Compute binary-classification metrics from predicted probabilities.

    Args:
        y_true: Ground-truth binary labels.
        y_pred_proba: Predicted probabilities of the positive class. Any
            array-like (list, numpy array, pandas Series) is accepted.
        threshold: Probabilities ``>= threshold`` count as positive
            predictions for the threshold-dependent metrics.

    Returns:
        dict mapping metric name to value. "AUC", "BrierScore" and "Logloss"
        are computed from the probabilities; "F1", "Accuracy", "Precision"
        and "Recall" from the thresholded 0/1 predictions.

    Raises:
        ValueError: propagated from scikit-learn, e.g. when ``y_true`` holds
            a single class and AUC is undefined.
    """
    import numpy as np

    # A plain list does not support elementwise ``>=``; normalize to an array.
    proba = np.asarray(y_pred_proba)
    y_pred = (proba >= threshold).astype(int)
    return {
        "AUC": roc_auc_score(y_true, proba),
        # zero_division=0 returns the same value scikit-learn's default uses
        # (0.0) for an undefined score, without emitting UndefinedMetricWarning.
        "F1": f1_score(y_true, y_pred, zero_division=0),
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, zero_division=0),
        "Recall": recall_score(y_true, y_pred, zero_division=0),
        "BrierScore": brier_score_loss(y_true, proba),
        "Logloss": log_loss(y_true, proba),
    }

def add_predictions(df, probs, threshold=0.5):
    """Add prediction columns to ``df`` and return a styled table for display.

    ``df`` is modified in place:
    - ``'Predicted Probability'`` is set to ``probs``.
    - ``'<target_col> Prediction'`` is set to ``'POSITIVE'`` / ``'NEGATIVE'`` per row.

    Those two columns are then joined with ``st.session_state.targets_df``
    (presumably the ground-truth labels) and styled.

    Args:
        df: DataFrame receiving the prediction columns (mutated in place).
        probs: Predicted positive-class probabilities, one per row of ``df``.
        threshold: Probabilities ``>= threshold`` are labelled ``'POSITIVE'``.
            Defaults to 0.5, the same default ``compute_metrics`` uses.

    Returns:
        pandas Styler with POSITIVE cells shaded green, NEGATIVE cells shaded
        red, and every cell centre-aligned.
    """
    df['Predicted Probability'] = probs
    pred_col = f"{st.session_state.target_col} Prediction"
    df[pred_col] = ['POSITIVE' if p >= threshold else 'NEGATIVE' for p in probs]

    df_with_gt = df[['Predicted Probability', pred_col]].join(st.session_state.targets_df)

    # Cell-level styling for the prediction column.
    def highlight_prediction(val):
        if val == "POSITIVE":
            return "background-color: #d4edda; color: #155724; text-align: center;"
        elif val == "NEGATIVE":
            return "background-color: #f8d7da; color: #721c24; text-align: center;"
        return "text-align: center;"

    styler = df_with_gt.style
    # Styler.applymap is deprecated since pandas 2.1 in favour of Styler.map;
    # prefer the new name and fall back for older pandas versions.
    map_cells = styler.map if hasattr(styler, "map") else styler.applymap

    df_styled = (
        map_cells(highlight_prediction, subset=[pred_col])
        .set_properties(**{'text-align': 'center'})  # centre-align all cells
    )

    return df_styled

def st_shap(plot, height=None):
    """Embed a SHAP JavaScript plot in the Streamlit page.

    The SHAP JS bundle (``shap.getjs()``) goes in the document head and the
    plot's own HTML (``plot.html()``) in the body. The result is rendered via
    ``components.html`` with the requested ``height``.
    """
    js_bundle = shap.getjs()
    plot_markup = plot.html()
    page = "<head>{}</head><body>{}</body>".format(js_bundle, plot_markup)
    components.html(page, height=height)

def ensemble_shap(models, X, model_weights=None):
    """
    Compute ensemble SHAP values for a list of tree-based models.

    Every model is explained on the same ``X`` with ``shap.TreeExplainer``.
    When an explainer returns one set of values per class (3-D ``values``),
    only the positive class is kept. The per-model SHAP values and base values
    are then combined with a weighted average.

    Args:
        models: Non-empty iterable of fitted models supported by
            ``shap.TreeExplainer``.
        X: Features to explain. ``X.columns`` supplies the feature names, so a
            pandas DataFrame is expected.
        model_weights: Optional weights, one per model. Defaults to equal
            weights. They are normalized to sum to 1.

    Returns:
        shap.Explanation holding the weighted mean SHAP values and base values.

    Raises:
        ValueError: if ``models`` is empty, or if ``model_weights`` has the
            wrong length or sums to zero.
    """
    import numpy as np

    # Materialize so generators work and len() is available.
    models = list(models)
    if not models:
        raise ValueError("ensemble_shap requires at least one model")

    # Validate weights up front so a bad argument fails before the
    # potentially expensive SHAP computations run.
    if model_weights is None:
        weights = np.ones(len(models))
    else:
        weights = np.asarray(model_weights, dtype=float)
    if weights.shape != (len(models),):
        raise ValueError(
            f"model_weights must have one entry per model "
            f"(expected {len(models)}, got shape {weights.shape})"
        )
    weight_total = weights.sum()
    if weight_total == 0:
        # Normalizing by zero would silently turn every SHAP value into NaN.
        raise ValueError("model_weights must not sum to zero")
    weights = weights / weight_total

    all_values = []
    all_base_values = []

    for model in models:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer(X)

        # Per-class output: values indexed as [sample, feature, class] and
        # base_values as [sample, class]. Keep only the positive class.
        if shap_values.values.ndim == 3:
            # ``classes_`` is normally a numpy array, which has no ``.index``
            # method, so convert to a list first. Prefer the label 1; otherwise
            # use the last class (scikit-learn's positive class for binary).
            classes = list(getattr(model, "classes_", [0, 1]))
            class_index = classes.index(1) if 1 in classes else len(classes) - 1
            shap_values = shap.Explanation(
                values=shap_values.values[:, :, class_index],
                base_values=shap_values.base_values[:, class_index],
                data=X,
                feature_names=X.columns,
            )

        all_values.append(shap_values.values)
        all_base_values.append(shap_values.base_values)

    mean_values = np.average(all_values, axis=0, weights=weights)
    mean_base = np.average(all_base_values, axis=0, weights=weights)

    return shap.Explanation(
        values=mean_values,
        base_values=mean_base,
        data=X,
        feature_names=X.columns,
    )