Spaces:
Sleeping
Sleeping
| """ | |
| Explainability Module - InsightGenAI | |
| ==================================== | |
| SHAP-based model explainability with feature importance plots, | |
| summary plots, and individual prediction explanations. | |
| Author: InsightGenAI Team | |
| Version: 1.0.0 | |
| """ | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
| from typing import Dict, List, Tuple, Optional, Any, Union | |
| import streamlit as st | |
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # Try to import shap, handle if not available | |
| try: | |
| import shap | |
| SHAP_AVAILABLE = True | |
| except ImportError: | |
| SHAP_AVAILABLE = False | |
class ExplainabilityEngine:
    """
    Model explainability engine using SHAP values.

    SHAP returns its values in one of three layouts, all of which are
    handled by this class:

    * a 2-D array ``(rows, features)`` for single-output models;
    * a list with one 2-D array per model output (older SHAP releases);
    * a 3-D array ``(rows, features, outputs)`` (newer SHAP releases).

    Plots and per-instance explanations always describe the first model
    output (index 0); global importance averages over every output.

    Attributes:
        model: Trained model to explain
        X: Feature matrix (a private copy, columns renamed to feature_names)
        feature_names: Names used to label the features
        explainer: SHAP explainer object
        shap_values: Calculated SHAP values, exactly as SHAP returned them
        expected_value: Base value(s) reported by the explainer
    """

    def __init__(self, model, X: pd.DataFrame, feature_names: Optional[List[str]] = None):
        """
        Initialize the Explainability Engine and compute SHAP values.

        Args:
            model: Trained model
            X: Feature data (sample for background); non-DataFrame input is
                wrapped in a DataFrame
            feature_names: List of feature names; when omitted or empty the
                column labels of ``X`` are used

        Raises:
            ImportError: If the ``shap`` package is not installed.
            RuntimeError: If no SHAP explainer could be built for the model.
        """
        if not SHAP_AVAILABLE:
            raise ImportError("SHAP is not installed. Please install with: pip install shap")

        self.model = model
        self.X = X.copy() if isinstance(X, pd.DataFrame) else pd.DataFrame(X)

        # An explicit length test (rather than plain truthiness) lets callers
        # pass a numpy array or pandas Index of names without triggering the
        # "truth value of an array is ambiguous" error.
        if feature_names is not None and len(feature_names) > 0:
            self.feature_names = list(feature_names)
        else:
            self.feature_names = self.X.columns.tolist()
        self.X.columns = self.feature_names

        self.explainer = None
        self.shap_values = None
        self.expected_value = None

        # Initialize SHAP explainer
        self._init_explainer()

    def _init_explainer(self) -> None:
        """
        Initialize the appropriate SHAP explainer for the model.

        TreeExplainer is tried first; if anything in that path fails, a
        KernelExplainer is built around ``model.predict`` with a background
        sample of 100 rows drawn via ``shap.sample``.

        Raises:
            RuntimeError: If the KernelExplainer fallback fails as well. The
                underlying exception is attached as ``__cause__``.
        """
        try:
            # Try TreeExplainer first (for tree-based models)
            self.explainer = shap.TreeExplainer(self.model)
            self.shap_values = self.explainer.shap_values(self.X)
            self.expected_value = self.explainer.expected_value
        except Exception:
            try:
                # Fall back to KernelExplainer
                self.explainer = shap.KernelExplainer(self.model.predict, shap.sample(self.X, 100))
                self.shap_values = self.explainer.shap_values(self.X)
                self.expected_value = self.explainer.expected_value
            except Exception as e:
                raise RuntimeError(f"Could not initialize SHAP explainer: {str(e)}") from e

    def _shap_matrix(self, output_idx: int = 0) -> np.ndarray:
        """
        Return the 2-D ``(rows, features)`` SHAP values for one model output.

        Args:
            output_idx: Which model output to select when the stored values
                cover several outputs (list or 3-D layout)

        Raises:
            ValueError: If SHAP values have not been calculated.
        """
        if self.shap_values is None:
            raise ValueError("SHAP values not calculated. Please initialize explainer first.")

        if isinstance(self.shap_values, list):
            return np.asarray(self.shap_values[output_idx])

        values = np.asarray(self.shap_values)
        if values.ndim == 3:
            # Newer SHAP layout: the outputs live on the last axis.
            return values[:, :, output_idx]
        return values

    def _base_value(self, output_idx: int = 0) -> float:
        """
        Return the explainer's base value for one model output as a float.

        ``expected_value`` may be a scalar, a one-element array, or one entry
        per output; the waterfall plot needs a plain scalar in every case.
        """
        base = np.asarray(self.expected_value, dtype=float).ravel()
        if base.size == 1:
            return float(base[0])
        return float(base[output_idx])

    def get_feature_importance(self) -> pd.DataFrame:
        """
        Get global feature importance based on mean absolute SHAP values.

        For multi-output models the absolute values are averaged over every
        output as well as over every row.

        Returns:
            pd.DataFrame with columns ``feature`` and ``importance``, sorted
            by importance in descending order

        Raises:
            ValueError: If SHAP values have not been calculated.
        """
        if self.shap_values is None:
            raise ValueError("SHAP values not calculated. Please initialize explainer first.")

        # Handle different shap_values formats
        if isinstance(self.shap_values, list):
            # (outputs, rows, features): average over outputs, then rows
            shap_array = np.abs(np.array(self.shap_values)).mean(axis=0).mean(axis=0)
        else:
            magnitudes = np.abs(np.asarray(self.shap_values))
            if magnitudes.ndim == 3:
                # (rows, features, outputs): keep only the feature axis
                shap_array = magnitudes.mean(axis=(0, 2))
            else:
                shap_array = magnitudes.mean(axis=0)

        importance_df = pd.DataFrame({
            'feature': self.feature_names,
            'importance': shap_array
        }).sort_values('importance', ascending=False)
        return importance_df

    def plot_summary(self, max_display: int = 15, figsize: Tuple[int, int] = (10, 8)) -> "plt.Figure":
        """
        Create SHAP summary plot (beeswarm plot) for the first model output.

        Args:
            max_display: Maximum number of features to display
            figsize: Figure size tuple

        Returns:
            matplotlib Figure object
        """
        fig, _ = plt.subplots(figsize=figsize)
        shap.summary_plot(
            self._shap_matrix(),
            self.X,
            feature_names=self.feature_names,
            max_display=max_display,
            # Without an explicit plot_size SHAP resizes the current figure
            # to its own "auto" dimensions and figsize would have no effect.
            plot_size=figsize,
            show=False
        )
        plt.title('SHAP Summary Plot', fontsize=14, fontweight='bold', pad=20)
        plt.tight_layout()
        return fig

    def plot_feature_importance(self, max_display: int = 15,
                                figsize: Tuple[int, int] = (10, 8)) -> "plt.Figure":
        """
        Create bar plot of feature importance for the first model output.

        Args:
            max_display: Maximum number of features to display
            figsize: Figure size tuple

        Returns:
            matplotlib Figure object
        """
        fig, _ = plt.subplots(figsize=figsize)
        shap.summary_plot(
            self._shap_matrix(),
            self.X,
            feature_names=self.feature_names,
            max_display=max_display,
            plot_type='bar',
            # See plot_summary: required for figsize to be honored.
            plot_size=figsize,
            show=False
        )
        plt.title('SHAP Feature Importance', fontsize=14, fontweight='bold', pad=20)
        plt.tight_layout()
        return fig

    def plot_waterfall(self, instance_idx: int = 0,
                       max_display: int = 10,
                       figsize: Tuple[int, int] = (12, 6)) -> "plt.Figure":
        """
        Create waterfall plot for a single prediction (first model output).

        Args:
            instance_idx: Positional index of the row in ``self.X`` to explain
            max_display: Maximum number of features to display
            figsize: Figure size tuple
                NOTE(review): shap.waterfall_plot appears to size the current
                figure itself, so this may not be honored - confirm.

        Returns:
            matplotlib Figure object
        """
        fig, _ = plt.subplots(figsize=figsize)
        shap.waterfall_plot(
            shap.Explanation(
                values=self._shap_matrix()[instance_idx],
                # Must be a scalar: SHAP rejects array-valued base values here.
                base_values=self._base_value(),
                data=self.X.iloc[instance_idx].values,
                feature_names=self.feature_names
            ),
            max_display=max_display,
            show=False
        )
        plt.title(f'SHAP Waterfall Plot - Instance {instance_idx}', fontsize=14, fontweight='bold', pad=20)
        plt.tight_layout()
        return fig

    def plot_dependence(self, feature: str,
                        interaction_feature: Optional[str] = None,
                        figsize: Tuple[int, int] = (10, 6)) -> "plt.Figure":
        """
        Create dependence plot for a feature (first model output).

        Args:
            feature: Feature name to plot
            interaction_feature: Feature to use for coloring; passed to SHAP
                as ``interaction_index`` unchanged (including ``None``)
            figsize: Figure size tuple

        Returns:
            matplotlib Figure object. When ``feature`` is not one of
            ``feature_names`` nothing is drawn and the figure contains only
            the title.
        """
        fig, ax = plt.subplots(figsize=figsize)
        shap_values_plot = self._shap_matrix()

        if feature in self.feature_names:
            shap.dependence_plot(
                self.feature_names.index(feature),
                shap_values_plot,
                self.X,
                feature_names=self.feature_names,
                interaction_index=interaction_feature,
                show=False,
                ax=ax
            )

        plt.title(f'SHAP Dependence Plot: {feature}', fontsize=14, fontweight='bold', pad=20)
        plt.tight_layout()
        return fig

    def explain_instance(self, instance_idx: int) -> Dict:
        """
        Get explanation for a single instance (first model output).

        Args:
            instance_idx: Positional index of the row in ``self.X``

        Returns:
            Dict with keys ``instance_index``, ``expected_value`` (scalar base
            value), ``prediction`` (base value plus the sum of the row's SHAP
            values) and ``contributions`` (one record per feature, ordered by
            absolute SHAP value, largest first)
        """
        shap_values = self._shap_matrix()[instance_idx]
        expected_value = self._base_value()

        # Create feature contribution dataframe
        contributions = pd.DataFrame({
            'feature': self.feature_names,
            'value': self.X.iloc[instance_idx].values,
            'shap_value': shap_values,
            'abs_shap_value': np.abs(shap_values)
        }).sort_values('abs_shap_value', ascending=False)

        prediction = expected_value + shap_values.sum()

        return {
            'instance_index': instance_idx,
            'expected_value': expected_value,
            'prediction': prediction,
            'contributions': contributions.to_dict('records')
        }

    def get_global_explanations(self) -> Dict:
        """
        Get global explanations for the model.

        Returns:
            Dict with ``top_features`` (up to ten importance records),
            ``feature_count`` and ``mean_shap_value`` (mean absolute SHAP
            value over every stored entry)
        """
        importance_df = self.get_feature_importance()
        return {
            'top_features': importance_df.head(10).to_dict('records'),
            'feature_count': len(self.feature_names),
            # np.asarray copes with the list, 2-D and 3-D layouts alike.
            'mean_shap_value': np.abs(np.asarray(self.shap_values)).mean()
        }
class FallbackExplainability:
    """
    Fallback explainability engine when SHAP is not available.

    Uses built-in feature importance from models: ``feature_importances_``
    when present, otherwise the absolute value of ``coef_``, otherwise
    scikit-learn permutation importance.
    """

    def __init__(self, model, X: pd.DataFrame, feature_names: Optional[List[str]] = None,
                 y: Optional[Any] = None):
        """
        Initialize fallback explainability.

        Args:
            model: Trained model
            X: Feature data; non-DataFrame input is wrapped in a DataFrame
            feature_names: List of feature names; when omitted or empty the
                column labels of ``X`` are used
            y: Optional target values aligned with ``X``. Only consulted by
                the permutation-importance branch of get_feature_importance.
        """
        self.model = model
        self.X = X.copy() if isinstance(X, pd.DataFrame) else pd.DataFrame(X)

        # An explicit length test (rather than plain truthiness) lets callers
        # pass a numpy array or pandas Index of names.
        if feature_names is not None and len(feature_names) > 0:
            self.feature_names = list(feature_names)
        else:
            self.feature_names = self.X.columns.tolist()
        self.y = y

    def get_feature_importance(self) -> pd.DataFrame:
        """
        Get feature importance from model.

        Returns:
            pd.DataFrame with columns ``feature`` and ``importance``, sorted
            by importance in descending order
        """
        if hasattr(self.model, 'feature_importances_'):
            importance = self.model.feature_importances_
        elif hasattr(self.model, 'coef_'):
            importance = np.abs(self.model.coef_)
            if importance.ndim > 1:
                # One coefficient row per output/class: average them.
                importance = importance.mean(axis=0)
        else:
            # Use permutation importance as fallback
            from sklearn.inspection import permutation_importance

            # Permutation importance measures the score drop against a target.
            # Scoring against a fabricated all-zero target says nothing about
            # the model, so when no y was supplied the model's own predictions
            # on the unpermuted data are used as the reference; the result
            # then reflects how strongly each feature drives the output.
            target = self.y if self.y is not None else self.model.predict(self.X)
            perm_importance = permutation_importance(self.model, self.X,
                                                     target,
                                                     n_repeats=5, random_state=42)
            importance = perm_importance.importances_mean

        importance_df = pd.DataFrame({
            'feature': self.feature_names,
            'importance': importance
        }).sort_values('importance', ascending=False)
        return importance_df

    def plot_feature_importance(self, max_display: int = 15,
                                figsize: Tuple[int, int] = (10, 8)) -> "plt.Figure":
        """
        Create bar plot of feature importance.

        Args:
            max_display: Maximum number of features to display
            figsize: Figure size tuple

        Returns:
            matplotlib Figure object
        """
        importance_df = self.get_feature_importance().head(max_display)

        fig, ax = plt.subplots(figsize=figsize)
        sns.barplot(data=importance_df, y='feature', x='importance', ax=ax, palette='viridis')
        ax.set_title('Feature Importance (Model Built-in)', fontsize=14, fontweight='bold')
        ax.set_xlabel('Importance')
        ax.set_ylabel('Feature')
        plt.tight_layout()
        return fig
def create_explainer(model, X: pd.DataFrame, feature_names: Optional[List[str]] = None):
    """
    Build the best explainer available for a model.

    A SHAP-backed engine is preferred. The built-in-importance fallback is
    returned when SHAP is not installed, or when constructing the SHAP engine
    raises (in which case a Streamlit warning carrying the error is shown).

    Args:
        model: Trained model
        X: Feature data
        feature_names: List of feature names

    Returns:
        ExplainabilityEngine or FallbackExplainability instance
    """
    # Without SHAP there is nothing to attempt.
    if not SHAP_AVAILABLE:
        return FallbackExplainability(model, X, feature_names)

    try:
        engine = ExplainabilityEngine(model, X, feature_names)
    except Exception as e:
        # Surface the reason in the UI, then degrade gracefully.
        st.warning(f"SHAP explainer failed, using fallback: {str(e)}")
        return FallbackExplainability(model, X, feature_names)
    return engine
| # Streamlit display functions | |
def _render_plot(build_figure, description: str) -> None:
    """
    Build one figure, show it in Streamlit, and always release it.

    Args:
        build_figure: Zero-argument callable returning a matplotlib Figure
        description: Short plot name used in the warning shown on failure
    """
    fig = None
    try:
        fig = build_figure()
        st.pyplot(fig)
    except Exception as e:
        st.warning(f"Could not generate {description}: {str(e)}")
    finally:
        # Figures created through pyplot stay registered until closed;
        # closing here keeps them from piling up across script reruns.
        if fig is not None:
            plt.close(fig)


def display_shap_explanations(explainer, X_sample: Optional[pd.DataFrame] = None):
    """
    Display SHAP explanations in Streamlit.

    Always renders the feature-importance plot. When ``explainer`` is an
    ExplainabilityEngine it also renders the SHAP summary plot and, if
    ``X_sample`` is non-empty, a waterfall plot. A plot that fails to render
    is reported with a warning and the remaining plots are still attempted.

    Args:
        explainer: Object exposing ``plot_feature_importance()``; an
            ExplainabilityEngine additionally provides the SHAP plots
        X_sample: Only checked for being non-empty to decide whether to show
            the waterfall plot; the plot itself explains row 0 of the
            explainer's own data
    """
    st.subheader("🔍 Model Explainability")

    if not SHAP_AVAILABLE:
        st.warning("SHAP is not installed. Using built-in feature importance instead.")

    # Feature importance
    st.write("### Feature Importance")
    _render_plot(explainer.plot_feature_importance, "feature importance plot")

    # The remaining plots exist only on the SHAP-backed engine.
    if not isinstance(explainer, ExplainabilityEngine):
        return

    st.write("### SHAP Summary Plot")
    _render_plot(explainer.plot_summary, "summary plot")

    # Waterfall plot for first instance
    if X_sample is not None and len(X_sample) > 0:
        st.write("### Individual Prediction Explanation")
        _render_plot(lambda: explainer.plot_waterfall(instance_idx=0), "waterfall plot")