import collections

import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression


class ExplainabilityEngine:
    """Computes SHAP-based global feature importance for tabular pipelines."""

    def explain_tabular(self, model_pipeline, X_sample):
        """
        Compute global feature importance for a tabular model using SHAP.

        SHAP values are computed on the preprocessed (encoded) feature
        space and then aggregated back onto the original column names of
        ``X_sample``.

        Parameters
        ----------
        model_pipeline : sklearn.pipeline.Pipeline
            Fitted pipeline with named steps ``"preprocessor"`` and
            ``"model"``.
        X_sample : pandas.DataFrame
            Raw sample rows (pre-transformation) used for the explanation.

        Returns
        -------
        dict[str, float]
            Up to the 10 most important original features, mapped to their
            aggregated mean(|SHAP|) importance, sorted descending.

        Raises
        ------
        ValueError
            If inputs are empty, pipeline steps are missing, the model
            type is unsupported, or SHAP output is inconsistent.
        """
        # -------------------- Validation --------------------
        if X_sample is None or X_sample.empty:
            raise ValueError("Sample data is empty, cannot compute explanations")
        if "preprocessor" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'preprocessor' step")
        if "model" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'model' step")

        preprocessor = model_pipeline.named_steps["preprocessor"]
        model = model_pipeline.named_steps["model"]

        # -------------------- Transform Data --------------------
        X_transformed = preprocessor.transform(X_sample)
        if X_transformed is None or X_transformed.shape[0] == 0:
            raise ValueError("Transformed data is empty after preprocessing")

        # -------------------- SHAP --------------------
        shap_values = self._compute_shap_values(model, X_transformed)

        # Global importance: mean absolute SHAP value per encoded feature.
        mean_shap = np.abs(shap_values).mean(axis=0)

        try:
            feature_names = preprocessor.get_feature_names_out()
        except Exception as e:
            raise ValueError(
                f"Failed to retrieve feature names from preprocessor: {e}"
            )

        if len(feature_names) != len(mean_shap):
            raise ValueError(
                "Mismatch between SHAP values and feature names"
            )

        # -------------------- Aggregate Importance --------------------
        aggregated_importance = self._aggregate_importance(
            feature_names, mean_shap, list(X_sample.columns)
        )
        if not aggregated_importance:
            raise ValueError("No feature importance computed after aggregation")

        # -------------------- Sort + Limit Output --------------------
        top_features = sorted(
            aggregated_importance.items(), key=lambda kv: kv[1], reverse=True
        )[:10]
        return dict(top_features)

    @staticmethod
    def _compute_shap_values(model, X_transformed):
        """
        Run the SHAP explainer appropriate for ``model`` and normalize its
        output to a 2-D ``(n_samples, n_features)`` array.

        BUG FIX: the original code re-wrapped any remaining per-class
        *list* of SHAP arrays with ``np.array(...)`` (the path taken by
        multiclass LogisticRegression under LinearExplainer), producing a
        3-D ``(classes, n, features)`` array whose ``mean(axis=0)`` has
        length ``n_samples`` and therefore always tripped the "Mismatch
        between SHAP values and feature names" error.  It also did not
        handle newer SHAP releases whose TreeExplainer returns a 3-D
        ``(n_samples, n_features, n_classes)`` ndarray for classifiers.
        """
        if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
            explainer = shap.TreeExplainer(model)
            # check_additivity=False: tolerate minor additivity drift on
            # transformed inputs (matches original behavior).
            values = explainer.shap_values(X_transformed, check_additivity=False)
        elif isinstance(model, (LinearRegression, LogisticRegression)):
            # Uses the (transformed) sample itself as the background data.
            explainer = shap.LinearExplainer(model, X_transformed)
            values = explainer.shap_values(X_transformed)
        else:
            raise ValueError(
                f"Explainability not supported for model type: {type(model)}"
            )

        if values is None or len(values) == 0:
            raise ValueError("SHAP computation failed")

        # Older SHAP: classifiers return a list with one array per class.
        if isinstance(values, list):
            if len(values) == 2:
                values = values[1]  # positive class for binary problems
            else:
                # Multiclass: average |SHAP| across classes -> 2-D result.
                values = np.abs(np.stack(values)).mean(axis=0)

        values = np.asarray(values)

        # Newer SHAP: classifiers return (n_samples, n_features, n_classes).
        if values.ndim == 3:
            if values.shape[-1] == 2:
                values = values[..., 1]  # positive class
            else:
                values = np.abs(values).mean(axis=-1)

        return values

    @staticmethod
    def _aggregate_importance(feature_names, mean_shap, original_columns):
        """
        Sum per-encoded-feature importances back onto original column names.

        Transformed names look like ``num__WindSpeed`` or
        ``cat__PaymentMethod_CreditCard``.  BUG FIX: the original
        heuristic (``split("__")[1].split("_")[0]``) truncated any
        original column that itself contains an underscore
        (e.g. ``num__Wind_Speed`` -> ``Wind``), silently merging unrelated
        features.  We now match the stripped name against the actual input
        columns first and only fall back to the old heuristic when no
        column matches.
        """
        aggregated = collections.defaultdict(float)
        # Longest columns first so "Payment_Method" wins over "Payment".
        columns_by_length = sorted(original_columns, key=len, reverse=True)

        for feature_name, importance in zip(feature_names, mean_shap):
            # Strip the transformer prefix ("num__", "cat__", ...).
            remainder = (
                feature_name.split("__", 1)[1]
                if "__" in feature_name
                else feature_name
            )
            for column in columns_by_length:
                # Exact match (numeric passthrough) or one-hot suffix
                # ("<column>_<category>").
                if remainder == column or remainder.startswith(column + "_"):
                    original_feature = column
                    break
            else:
                # Fallback: original prefix heuristic.
                original_feature = remainder.split("_")[0]
            aggregated[original_feature] += float(importance)

        return aggregated