Spaces:
Sleeping
Sleeping
import collections

import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
class ExplainabilityEngine:
    """SHAP-based global feature importance for tabular sklearn pipelines.

    Expects a fitted ``Pipeline`` with named steps ``"preprocessor"``
    (exposing ``transform`` and ``get_feature_names_out``) and ``"model"``
    (a supported sklearn estimator).
    """

    def explain_tabular(self, model_pipeline, X_sample, top_n=10):
        """Compute global feature importance for a tabular model using SHAP.

        SHAP values are computed on the preprocessed (transformed) features
        and then aggregated back onto the original feature names, so that
        e.g. one-hot encoded columns contribute to a single entry.

        Args:
            model_pipeline: sklearn Pipeline with "preprocessor" and "model"
                named steps.
            X_sample: pandas DataFrame of raw (untransformed) sample rows.
            top_n: maximum number of features to return (default 10,
                preserving the previous hard-coded limit).

        Returns:
            dict mapping original feature name -> aggregated mean(|SHAP|),
            sorted descending, limited to ``top_n`` entries.

        Raises:
            ValueError: on empty input, missing pipeline steps, unsupported
                model types, or inconsistent SHAP output.
        """
        # -------------------- Validation --------------------
        if X_sample is None or X_sample.empty:
            raise ValueError("Sample data is empty, cannot compute explanations")
        if "preprocessor" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'preprocessor' step")
        if "model" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'model' step")

        preprocessor = model_pipeline.named_steps["preprocessor"]
        model = model_pipeline.named_steps["model"]

        # -------------------- Transform Data --------------------
        X_transformed = preprocessor.transform(X_sample)
        if X_transformed is None or X_transformed.shape[0] == 0:
            raise ValueError("Transformed data is empty after preprocessing")

        # -------------------- SHAP --------------------
        # Normalized to a 2-D (n_samples, n_transformed_features) array
        # regardless of SHAP version / model type.
        shap_values = self._compute_shap_values(model, X_transformed)

        # -------------------- Aggregate Importance --------------------
        # Global importance per transformed column: mean absolute SHAP value.
        mean_shap = np.abs(shap_values).mean(axis=0)

        try:
            feature_names = preprocessor.get_feature_names_out()
        except Exception as e:
            raise ValueError(
                f"Failed to retrieve feature names from preprocessor: {e}"
            )

        if len(feature_names) != len(mean_shap):
            raise ValueError(
                "Mismatch between SHAP values and feature names"
            )

        aggregated_importance = self._aggregate_to_original_features(
            feature_names, mean_shap
        )
        if not aggregated_importance:
            raise ValueError("No feature importance computed after aggregation")

        # -------------------- Sort + Limit Output --------------------
        return dict(
            sorted(
                aggregated_importance.items(),
                key=lambda item: item[1],
                reverse=True,
            )[:top_n]
        )

    @staticmethod
    def _compute_shap_values(model, X_transformed):
        """Pick an explainer for ``model`` and return 2-D SHAP values.

        SHAP's output shape varies across versions: older releases return a
        list of per-class arrays for classifiers, newer ones (>= 0.45) a 3-D
        ``(n_samples, n_features, n_classes)`` array. Both are normalized
        here by selecting the positive class (index 1) — the previous code
        only handled the legacy list form, and only in the tree branch, so
        modern SHAP builds crashed later with per-feature arrays.

        Raises:
            ValueError: if the model type is unsupported or SHAP returns
                nothing.
        """
        if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
            explainer = shap.TreeExplainer(model)
            shap_values = explainer.shap_values(
                X_transformed, check_additivity=False
            )
        elif isinstance(model, (LinearRegression, LogisticRegression)):
            explainer = shap.LinearExplainer(model, X_transformed)
            shap_values = explainer.shap_values(X_transformed)
        else:
            raise ValueError(
                f"Explainability not supported for model type: {type(model)}"
            )

        if shap_values is None or len(shap_values) == 0:
            raise ValueError("SHAP computation failed")

        # Legacy SHAP: list of per-class arrays -> take the positive class.
        # Guarded so a single-element list no longer raises IndexError.
        # NOTE(review): for multiclass (>2 classes) this still picks class
        # index 1 only — multiclass aggregation is not supported here.
        if isinstance(shap_values, list):
            shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]
        shap_values = np.asarray(shap_values)
        # Modern SHAP: class axis is last -> slice out the positive class.
        if shap_values.ndim == 3:
            shap_values = shap_values[..., 1]
        return shap_values

    @staticmethod
    def _aggregate_to_original_features(feature_names, mean_shap):
        """Sum per-column importances back onto original feature names.

        Transformer-prefixed names look like ``num__WindSpeed`` or
        ``cat__PaymentMethod_CreditCard``: the transformer prefix before
        ``__`` is dropped, and any one-hot category suffix after the first
        remaining ``_`` is folded into its parent feature.

        NOTE(review): original feature names that themselves contain "_"
        (e.g. ``num__total_charges``) are truncated at the first underscore
        by this heuristic — unchanged from the prior behavior, confirm this
        is acceptable for the project's feature naming.
        """
        aggregated = collections.defaultdict(float)
        for feature_name, importance in zip(feature_names, mean_shap):
            if "__" in feature_name:
                original_feature = feature_name.split("__")[1].split("_")[0]
            else:
                original_feature = feature_name
            aggregated[original_feature] += float(importance)
        return aggregated