Spaces:
No application file
No application file
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from sklearn.inspection import permutation_importance | |
| import matplotlib.pyplot as plt | |
| import shap | |
| def explain_model(model, X, y=None, feature_names=None): | |
| """ | |
| Explain model predictions using various techniques | |
| """ | |
| st.subheader("π Model Explainability") | |
| if feature_names is None: | |
| feature_names = X.columns if hasattr(X, 'columns') else [f"Feature {i}" for i in range(X.shape[1])] | |
| # Create tabs for different explanation methods | |
| tab1, tab2, tab3 = st.tabs(["Feature Importance", "SHAP Values", "Partial Dependence"]) | |
| with tab1: | |
| st.markdown("### π Feature Importance") | |
| # Method selection | |
| method = st.radio( | |
| "Importance method", | |
| ["Built-in", "Permutation"], | |
| horizontal=True | |
| ) | |
| if method == "Built-in": | |
| if hasattr(model, 'feature_importances_'): | |
| importance = model.feature_importances_ | |
| importance_df = pd.DataFrame({ | |
| 'feature': feature_names, | |
| 'importance': importance | |
| }).sort_values('importance', ascending=False) | |
| fig = px.bar(importance_df.head(20), x='importance', y='feature', | |
| orientation='h', title="Feature Importance (Built-in)") | |
| st.plotly_chart(fig, use_container_width=True) | |
| else: | |
| st.warning("Model doesn't have built-in feature importance") | |
| else: # Permutation importance | |
| if y is not None: | |
| with st.spinner("Calculating permutation importance..."): | |
| perm_importance = permutation_importance(model, X, y, n_repeats=10) | |
| importance_df = pd.DataFrame({ | |
| 'feature': feature_names, | |
| 'importance': perm_importance.importances_mean, | |
| 'std': perm_importance.importances_std | |
| }).sort_values('importance', ascending=False) | |
| fig = go.Figure() | |
| fig.add_trace(go.Bar( | |
| x=importance_df['importance'].head(20), | |
| y=importance_df['feature'].head(20), | |
| orientation='h', | |
| error_x=dict( | |
| type='data', | |
| array=importance_df['std'].head(20), | |
| visible=True | |
| ) | |
| )) | |
| fig.update_layout(title="Permutation Importance (with error bars)", | |
| xaxis_title="Importance") | |
| st.plotly_chart(fig, use_container_width=True) | |
| else: | |
| st.warning("Need target values for permutation importance") | |
| with tab2: | |
| st.markdown("### π SHAP Values") | |
| if hasattr(model, 'predict'): | |
| with st.spinner("Calculating SHAP values (this may take a moment)..."): | |
| try: | |
| # Create explainer based on model type | |
| if str(type(model)).find('sklearn') != -1: | |
| explainer = shap.Explainer(model, X[:100]) # Use subset for speed | |
| else: | |
| explainer = shap.TreeExplainer(model) if hasattr(model, 'feature_importances_') else shap.Explainer(model, X[:100]) | |
| # Calculate SHAP values | |
| shap_values = explainer(X[:100]) # Limit to 100 samples for performance | |
| # Summary plot | |
| st.markdown("#### SHAP Summary Plot") | |
| fig, ax = plt.subplots() | |
| shap.summary_plot(shap_values, X[:100], feature_names=feature_names, show=False) | |
| st.pyplot(fig) | |
| plt.close() | |
| # Waterfall plot for a single prediction | |
| st.markdown("#### Single Prediction Explanation") | |
| sample_idx = st.slider("Select sample index", 0, min(99, len(X)-1), 0) | |
| fig, ax = plt.subplots() | |
| shap.waterfall_plot(shap_values[sample_idx], show=False) | |
| st.pyplot(fig) | |
| plt.close() | |
| except Exception as e: | |
| st.error(f"Error calculating SHAP values: {str(e)}") | |
| st.info("Try using a smaller sample or a different model type") | |
| else: | |
| st.warning("Model doesn't support prediction") | |
| with tab3: | |
| st.markdown("### π Partial Dependence Plots") | |
| if hasattr(model, 'predict') and len(feature_names) > 0: | |
| from sklearn.inspection import partial_dependence | |
| selected_feature = st.selectbox("Select feature for PDP", feature_names) | |
| if selected_feature: | |
| feature_idx = list(feature_names).index(selected_feature) | |
| # Calculate partial dependence | |
| pdp = partial_dependence(model, X, [feature_idx], grid_resolution=50) | |
| # Create plot | |
| fig = go.Figure() | |
| fig.add_trace(go.Scatter( | |
| x=pdp['values'][0], | |
| y=pdp['average'][0], | |
| mode='lines+markers', | |
| name='Partial Dependence' | |
| )) | |
| fig.update_layout( | |
| title=f"Partial Dependence Plot for {selected_feature}", | |
| xaxis_title=selected_feature, | |
| yaxis_title="Prediction" | |
| ) | |
| st.plotly_chart(fig, use_container_width=True) | |
| # Individual conditional expectation (ICE) plots | |
| if st.checkbox("Show ICE plots"): | |
| ice_data = [] | |
| for i in range(min(10, X.shape[0])): # Show up to 10 lines | |
| ice = partial_dependence(model, X.iloc[i:i+1], [feature_idx], grid_resolution=20) | |
| ice_data.append(ice['average'][0]) | |
| fig = go.Figure() | |
| for i, ice in enumerate(ice_data): | |
| fig.add_trace(go.Scatter( | |
| x=pdp['values'][0], | |
| y=ice, | |
| mode='lines', | |
| name=f'Sample {i}', | |
| line=dict(width=1, color='lightgray') | |
| )) | |
| # Add average line | |
| fig.add_trace(go.Scatter( | |
| x=pdp['values'][0], | |
| y=pdp['average'][0], | |
| mode='lines', | |
| name='Average', | |
| line=dict(width=3, color='red') | |
| )) | |
| fig.update_layout( | |
| title=f"ICE Plots for {selected_feature}", | |
| xaxis_title=selected_feature, | |
| yaxis_title="Prediction" | |
| ) | |
| st.plotly_chart(fig, use_container_width=True) | |
| else: | |
| st.warning("Need more features for partial dependence plots") |