# ModelSmith-AI — backend/core/explainability.py
# (Author: ACA050 — "Update backend/core/explainability.py", commit 654bcd9, verified)
import shap
import numpy as np
import collections
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression
class ExplainabilityEngine:
    """Compute global feature importance for fitted sklearn pipelines via SHAP."""

    # Maximum number of aggregated original features returned to the caller.
    TOP_K = 10

    def explain_tabular(self, model_pipeline, X_sample):
        """
        Compute global feature importance for a tabular model using SHAP.

        Aggregates mean |SHAP| values of the transformed (encoded) features
        back onto the original input columns.

        Parameters
        ----------
        model_pipeline : sklearn Pipeline with named steps ``"preprocessor"``
            and ``"model"``.
        X_sample : pandas DataFrame of raw (untransformed) rows.

        Returns
        -------
        dict
            Original feature name -> aggregated mean |SHAP| importance,
            sorted descending, truncated to the top 10 features.

        Raises
        ------
        ValueError
            On empty input, missing pipeline steps, unsupported model type,
            or SHAP/feature-name inconsistencies.
        """
        # -------------------- Validation --------------------
        if X_sample is None or X_sample.empty:
            raise ValueError("Sample data is empty, cannot compute explanations")
        if "preprocessor" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'preprocessor' step")
        if "model" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'model' step")

        preprocessor = model_pipeline.named_steps["preprocessor"]
        model = model_pipeline.named_steps["model"]

        # -------------------- Transform Data --------------------
        X_transformed = preprocessor.transform(X_sample)
        if X_transformed is None or X_transformed.shape[0] == 0:
            raise ValueError("Transformed data is empty after preprocessing")
        # SHAP explainers need dense input; a OneHotEncoder inside the
        # preprocessor may emit a scipy sparse matrix (duck-typed check so
        # no scipy import is required here).
        if hasattr(X_transformed, "toarray"):
            X_transformed = X_transformed.toarray()

        # -------------------- SHAP Values --------------------
        shap_values = self._compute_shap_values(model, X_transformed)

        # Mean absolute SHAP value per transformed feature.
        mean_shap = np.abs(shap_values).mean(axis=0)

        try:
            feature_names = preprocessor.get_feature_names_out()
        except Exception as e:
            raise ValueError(
                f"Failed to retrieve feature names from preprocessor: {e}"
            )
        if len(feature_names) != len(mean_shap):
            raise ValueError(
                "Mismatch between SHAP values and feature names"
            )

        # -------------------- Aggregate Importance --------------------
        # Fold encoded feature importances back onto the original columns,
        # using X_sample's real column names so names containing underscores
        # are not split apart.
        original_columns = [str(c) for c in X_sample.columns]
        aggregated_importance = collections.defaultdict(float)
        for feature_name, importance in zip(feature_names, mean_shap):
            original_feature = self._map_to_original_feature(
                feature_name, original_columns
            )
            aggregated_importance[original_feature] += float(importance)

        if not aggregated_importance:
            raise ValueError("No feature importance computed after aggregation")

        # -------------------- Sort + Limit Output --------------------
        top = sorted(
            aggregated_importance.items(), key=lambda kv: kv[1], reverse=True
        )[: self.TOP_K]
        return dict(top)

    def _compute_shap_values(self, model, X_transformed):
        """Pick a SHAP explainer for *model* and return a 2-D (n_samples,
        n_features) array of SHAP values.

        Normalizes across SHAP versions: older releases return a list of
        per-class arrays for classifiers, newer releases return a 3-D
        (n_samples, n_features, n_classes) array. Either way the positive
        class is selected for binary classification.

        Raises ValueError for unsupported model types or empty SHAP output.
        """
        if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
            explainer = shap.TreeExplainer(model)
            shap_values = explainer.shap_values(
                X_transformed, check_additivity=False
            )
        elif isinstance(model, (LinearRegression, LogisticRegression)):
            explainer = shap.LinearExplainer(model, X_transformed)
            shap_values = explainer.shap_values(X_transformed)
        else:
            raise ValueError(
                f"Explainability not supported for model type: {type(model)}"
            )

        if shap_values is None or len(shap_values) == 0:
            raise ValueError("SHAP computation failed")

        # Older SHAP: list of per-class arrays. Take the positive class when
        # there are two; guard against a single-element list (regressors).
        if isinstance(shap_values, list):
            shap_values = (
                shap_values[1] if len(shap_values) > 1 else shap_values[0]
            )

        shap_values = np.asarray(shap_values)

        # Newer SHAP: (n_samples, n_features, n_classes) for classifiers.
        # Slice out the last (positive) class to keep a 2-D matrix aligned
        # with the transformed feature names.
        if shap_values.ndim == 3:
            shap_values = shap_values[:, :, -1]
        return shap_values

    @staticmethod
    def _map_to_original_feature(transformed_name, original_columns):
        """Map a transformer output name back to its original column name.

        Examples of transformed names:
            num__WindSpeed
            cat__PaymentMethod_CreditCard

        Matching against *original_columns* (longest match wins) keeps
        original names that themselves contain underscores intact, e.g.
        ``num__Wind_Speed`` -> ``Wind_Speed`` rather than ``Wind``.
        """
        # Strip the ColumnTransformer prefix ('num__', 'cat__', ...).
        if "__" in transformed_name:
            base = transformed_name.split("__", 1)[1]
            had_prefix = True
        else:
            base = transformed_name
            had_prefix = False

        # Longest exact/prefix match against the real input columns.
        best = None
        for col in original_columns:
            if base == col or base.startswith(col + "_"):
                if best is None or len(col) > len(best):
                    best = col
        if best is not None:
            return best

        # Fallback: historical heuristic (token before the first '_') for
        # prefixed names; pass unprefixed names through unchanged.
        return base.split("_")[0] if had_prefix else base