Spaces:
Sleeping
Sleeping
File size: 3,930 Bytes
6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 6f5e4ab 654bcd9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import shap
import numpy as np
import collections
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
class ExplainabilityEngine:
    """Computes global, human-readable feature importance for tabular models via SHAP."""

    def explain_tabular(self, model_pipeline, X_sample):
        """
        Compute global feature importance for a tabular sklearn pipeline using SHAP.

        Aggregates SHAP values computed on the *transformed* feature space back
        to the original (pre-encoding) feature names.

        Parameters
        ----------
        model_pipeline : sklearn.pipeline.Pipeline
            Must expose named steps ``"preprocessor"`` and ``"model"``.
        X_sample : pandas.DataFrame
            Non-empty sample used to estimate importance.

        Returns
        -------
        dict[str, float]
            Up to 10 original-feature names mapped to their aggregated mean
            absolute SHAP value, sorted by descending importance.

        Raises
        ------
        ValueError
            On empty input, a malformed pipeline, an unsupported model type,
            or inconsistent SHAP output / feature names.
        """
        # -------------------- Validation --------------------
        if X_sample is None or X_sample.empty:
            raise ValueError("Sample data is empty, cannot compute explanations")
        if "preprocessor" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'preprocessor' step")
        if "model" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'model' step")
        preprocessor = model_pipeline.named_steps["preprocessor"]
        model = model_pipeline.named_steps["model"]

        # -------------------- Transform Data --------------------
        X_transformed = preprocessor.transform(X_sample)
        if X_transformed is None or X_transformed.shape[0] == 0:
            raise ValueError("Transformed data is empty after preprocessing")

        # -------------------- SHAP Explainer Selection --------------------
        if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
            explainer = shap.TreeExplainer(model)
            # check_additivity=False: tolerate small numerical drift between
            # the model output and the sum of per-feature SHAP values.
            shap_values = explainer.shap_values(
                X_transformed, check_additivity=False
            )
            # Classifier output shape varies by SHAP version:
            #   - older SHAP: list of per-class arrays
            #   - newer SHAP: single 3-D array (samples, features, classes)
            # In both cases reduce to the positive class (index 1) for the
            # binary case, matching the original behavior.
            if isinstance(shap_values, list):
                shap_values = shap_values[1]
            elif getattr(shap_values, "ndim", 2) == 3:
                shap_values = shap_values[:, :, 1]
        elif isinstance(model, (LinearRegression, LogisticRegression)):
            explainer = shap.LinearExplainer(model, X_transformed)
            shap_values = explainer.shap_values(X_transformed)
        else:
            raise ValueError(
                f"Explainability not supported for model type: {type(model)}"
            )

        # -------------------- Validate SHAP Output --------------------
        if shap_values is None or len(shap_values) == 0:
            raise ValueError("SHAP computation failed")
        if isinstance(shap_values, list):
            shap_values = np.array(shap_values)

        # -------------------- Aggregate Importance --------------------
        # Mean absolute SHAP per transformed feature = global importance.
        mean_shap = np.abs(shap_values).mean(axis=0)
        try:
            feature_names = preprocessor.get_feature_names_out()
        except Exception as e:
            raise ValueError(
                f"Failed to retrieve feature names from preprocessor: {e}"
            ) from e
        if len(feature_names) != len(mean_shap):
            raise ValueError(
                "Mismatch between SHAP values and feature names"
            )

        # Aggregate encoded features back to original feature names, e.g.
        #   num__WindSpeed                -> WindSpeed
        #   cat__PaymentMethod_CreditCard -> PaymentMethod
        # NOTE(review): an original feature name containing "_" is truncated
        # at its first underscore here — confirm upstream column naming.
        aggregated_importance = collections.defaultdict(float)
        for feature_name, importance in zip(feature_names, mean_shap):
            if "__" in feature_name:
                original_feature = feature_name.split("__")[1].split("_")[0]
            else:
                original_feature = feature_name
            aggregated_importance[original_feature] += float(importance)
        if not aggregated_importance:
            raise ValueError("No feature importance computed after aggregation")

        # -------------------- Sort + Limit Output --------------------
        return dict(
            sorted(
                aggregated_importance.items(),
                key=lambda item: item[1],
                reverse=True,
            )[:10]
        )
|