File size: 7,836 Bytes
da8e446
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import streamlit as st
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.inspection import permutation_importance
import matplotlib.pyplot as plt
import shap

# Display / computation limits for the explainability panel.
_TOP_N_FEATURES = 20        # bars shown in the importance charts
_PERMUTATION_REPEATS = 10   # shuffles per feature for permutation importance
_SHAP_MAX_SAMPLES = 100     # rows of X explained (also the SHAP background data)
_PDP_GRID_RESOLUTION = 50   # grid points for partial dependence / ICE curves
_ICE_MAX_LINES = 10         # individual ICE curves drawn


def explain_model(model, X, y=None, feature_names=None):
    """
    Render a Streamlit panel explaining a fitted model's predictions.

    The panel has three tabs:

    * Feature Importance - the model's ``feature_importances_`` or
      scikit-learn permutation importance (the latter needs ``y``).
    * SHAP Values - a SHAP summary plot plus a waterfall plot for one
      selected row, computed on the first ``_SHAP_MAX_SAMPLES`` rows of ``X``.
    * Partial Dependence - a partial dependence plot (PDP) for one selected
      feature, optionally with individual conditional expectation (ICE) curves.

    Args:
        model: Fitted model. Permutation importance and partial dependence
            are computed with scikit-learn, so a scikit-learn-compatible
            estimator is assumed.
        X: Feature matrix, either a pandas DataFrame or a 2-D array.
        y: Optional targets matching ``X``; only permutation importance
            uses them.
        feature_names: Optional feature names in column order. Defaults to
            ``X.columns`` when ``X`` has columns, otherwise
            ``"Feature 0"``, ``"Feature 1"``, ...

    Returns:
        None. Everything is rendered into the current Streamlit page.
    """
    st.subheader("🔍 Model Explainability")

    if feature_names is None:
        feature_names = _default_feature_names(X)

    tab_importance, tab_shap, tab_pdp = st.tabs(
        ["Feature Importance", "SHAP Values", "Partial Dependence"]
    )
    with tab_importance:
        _render_feature_importance(model, X, y, feature_names)
    with tab_shap:
        _render_shap(model, X, feature_names)
    with tab_pdp:
        _render_partial_dependence(model, X, feature_names)


def _default_feature_names(X):
    """Return ``X.columns`` if ``X`` has them, else ``["Feature 0", ...]``."""
    if hasattr(X, "columns"):
        return X.columns
    return [f"Feature {i}" for i in range(X.shape[1])]


def _pdp_grid_values(pdp_result):
    """Return the grid of the first feature from a ``partial_dependence`` result.

    scikit-learn 1.3 renamed the ``values`` key to ``grid_values`` and later
    removed the old key, so prefer the new name and fall back to the old one.
    """
    key = "grid_values" if "grid_values" in pdp_result else "values"
    return pdp_result[key][0]


def _render_feature_importance(model, X, y, feature_names):
    """Render the built-in or permutation feature-importance bar chart."""
    st.markdown("### 📊 Feature Importance")

    method = st.radio(
        "Importance method",
        ["Built-in", "Permutation"],
        horizontal=True
    )

    if method == "Built-in":
        if not hasattr(model, 'feature_importances_'):
            st.warning("Model doesn't have built-in feature importance")
            return
        importances = model.feature_importances_
        if len(importances) != len(feature_names):
            # e.g. a model fitted on transformed features; pairing the two
            # lists would raise inside pandas.
            st.warning(
                f"Model reports {len(importances)} importances but "
                f"{len(feature_names)} feature names were given"
            )
            return
        importance_df = pd.DataFrame({
            'feature': feature_names,
            'importance': importances
        }).sort_values('importance', ascending=False)

        fig = px.bar(importance_df.head(_TOP_N_FEATURES), x='importance', y='feature',
                     orientation='h', title="Feature Importance (Built-in)")
        # Horizontal bars list the first category at the bottom; reverse so
        # the most important feature is on top.
        fig.update_yaxes(autorange="reversed")
        st.plotly_chart(fig, use_container_width=True)
        return

    # Permutation importance
    if y is None:
        st.warning("Need target values for permutation importance")
        return

    with st.spinner("Calculating permutation importance..."):
        try:
            # Fixed seed so the chart does not change each time Streamlit
            # re-executes the script.
            perm_importance = permutation_importance(
                model, X, y, n_repeats=_PERMUTATION_REPEATS, random_state=0
            )
        except Exception as e:
            st.error(f"Error calculating permutation importance: {e}")
            return

    importance_df = pd.DataFrame({
        'feature': feature_names,
        'importance': perm_importance.importances_mean,
        'std': perm_importance.importances_std
    }).sort_values('importance', ascending=False)
    top = importance_df.head(_TOP_N_FEATURES)

    fig = go.Figure(data=[go.Bar(
        x=top['importance'],
        y=top['feature'],
        orientation='h',
        error_x=dict(type='data', array=top['std'], visible=True),
    )])
    fig.update_layout(title="Permutation Importance (with error bars)",
                      xaxis_title="Importance")
    fig.update_yaxes(autorange="reversed")
    st.plotly_chart(fig, use_container_width=True)


def _make_shap_explainer(model, background):
    """Build a SHAP explainer for ``model`` with ``background`` as reference data.

    Non-scikit-learn models exposing ``feature_importances_`` get a
    ``TreeExplainer``; everything else goes through ``shap.Explainer``. If
    SHAP cannot analyse the model object itself, explain ``model.predict``.
    """
    if 'sklearn' not in str(type(model)) and hasattr(model, 'feature_importances_'):
        return shap.TreeExplainer(model)
    try:
        return shap.Explainer(model, background)
    except TypeError:
        # shap.Explainer raises TypeError for model objects it does not
        # recognise; a prediction function uses the model-agnostic path.
        return shap.Explainer(model.predict, background)


def _select_single_output(shap_values):
    """Reduce a multi-output SHAP Explanation to one user-selected output.

    Multi-output (e.g. multi-class) models yield values indexed as
    (rows, features, outputs); the waterfall plot needs a single output.
    Defaults to the last output (the positive class for binary classifiers).
    """
    if shap_values.values.ndim < 3:
        return shap_values
    n_outputs = shap_values.values.shape[2]
    output_idx = st.selectbox(
        "Output (class) to explain", list(range(n_outputs)), index=n_outputs - 1
    )
    return shap_values[:, :, output_idx]


def _render_matplotlib(draw):
    """Call ``draw()`` on a fresh figure, show the result in Streamlit, close it.

    Figures are closed even when ``draw`` raises, so failures do not leave
    matplotlib figures open across reruns.
    """
    created = plt.figure()
    shown = created
    try:
        draw()
        # Show whichever figure is current after drawing, in case the
        # plotting call switched to another one.
        shown = plt.gcf()
        st.pyplot(shown)
    finally:
        plt.close(shown)
        if shown is not created:
            plt.close(created)


def _render_shap(model, X, feature_names):
    """Render a SHAP summary plot and a single-row waterfall plot."""
    st.markdown("### 📈 SHAP Values")

    if not hasattr(model, 'predict'):
        st.warning("Model doesn't support prediction")
        return

    background = X[:_SHAP_MAX_SAMPLES]  # subset for speed
    try:
        with st.spinner("Calculating SHAP values (this may take a moment)..."):
            explainer = _make_shap_explainer(model, background)
            shap_values = explainer(background)
        shap_values = _select_single_output(shap_values)

        st.markdown("#### SHAP Summary Plot")
        _render_matplotlib(lambda: shap.summary_plot(
            shap_values, background, feature_names=feature_names, show=False
        ))

        st.markdown("#### Single Prediction Explanation")
        n_rows = min(_SHAP_MAX_SAMPLES, len(X))
        sample_idx = 0
        if n_rows > 1:  # a slider over a single row would be degenerate
            sample_idx = st.slider("Select sample index", 0, n_rows - 1, 0)
        _render_matplotlib(
            lambda: shap.waterfall_plot(shap_values[sample_idx], show=False)
        )
    except Exception as e:
        st.error(f"Error calculating SHAP values: {str(e)}")
        st.info("Try using a smaller sample or a different model type")


def _render_partial_dependence(model, X, feature_names):
    """Render a partial dependence plot, optionally with ICE curves."""
    st.markdown("### 📉 Partial Dependence Plots")

    if not (hasattr(model, 'predict') and len(feature_names) > 0):
        st.warning("Need more features for partial dependence plots")
        return

    from sklearn.inspection import partial_dependence

    selected_feature = st.selectbox("Select feature for PDP", feature_names)
    # Compare with None explicitly: names such as integer column 0 are falsy.
    if selected_feature is None:
        return
    feature_idx = list(feature_names).index(selected_feature)

    try:
        pdp = partial_dependence(model, X, [feature_idx],
                                 grid_resolution=_PDP_GRID_RESOLUTION)
    except Exception as e:
        st.error(f"Error calculating partial dependence: {e}")
        return

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=_pdp_grid_values(pdp),
        y=pdp['average'][0],
        mode='lines+markers',
        name='Partial Dependence'
    ))
    fig.update_layout(
        title=f"Partial Dependence Plot for {selected_feature}",
        xaxis_title=selected_feature,
        yaxis_title="Prediction"
    )
    st.plotly_chart(fig, use_container_width=True)

    if not st.checkbox("Show ICE plots"):
        return

    try:
        # One call over all of X on the same grid: ICE curves and their
        # average share the x-axis, and the grid is built from the whole
        # dataset rather than from a single row.
        ice = partial_dependence(model, X, [feature_idx], kind="both",
                                 grid_resolution=_PDP_GRID_RESOLUTION)
    except Exception as e:
        st.error(f"Error calculating ICE curves: {e}")
        return

    grid = _pdp_grid_values(ice)
    fig = go.Figure()
    for i, curve in enumerate(ice['individual'][0][:_ICE_MAX_LINES]):
        fig.add_trace(go.Scatter(
            x=grid,
            y=curve,
            mode='lines',
            name=f'Sample {i}',
            line=dict(width=1, color='lightgray')
        ))

    # Average line
    fig.add_trace(go.Scatter(
        x=grid,
        y=ice['average'][0],
        mode='lines',
        name='Average',
        line=dict(width=3, color='red')
    ))
    fig.update_layout(
        title=f"ICE Plots for {selected_feature}",
        xaxis_title=selected_feature,
        yaxis_title="Prediction"
    )
    st.plotly_chart(fig, use_container_width=True)