Spaces:
Sleeping
Sleeping
Update backend/core/explainability.py
Browse files
backend/core/explainability.py
CHANGED
|
@@ -1,30 +1,44 @@
|
|
| 1 |
import shap
|
| 2 |
import numpy as np
|
|
|
|
| 3 |
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
|
| 4 |
from sklearn.linear_model import LinearRegression, LogisticRegression
|
| 5 |
|
| 6 |
|
| 7 |
class ExplainabilityEngine:
|
| 8 |
def explain_tabular(self, model_pipeline, X_sample):
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
raise ValueError("Sample data is empty, cannot compute explanations")
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
preprocessor = model_pipeline.named_steps["preprocessor"]
|
| 13 |
model = model_pipeline.named_steps["model"]
|
| 14 |
|
|
|
|
| 15 |
X_transformed = preprocessor.transform(X_sample)
|
| 16 |
|
| 17 |
-
if X_transformed.shape[0] == 0:
|
| 18 |
-
raise ValueError("Transformed
|
| 19 |
|
| 20 |
-
#
|
| 21 |
if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
|
| 22 |
explainer = shap.TreeExplainer(model)
|
| 23 |
shap_values = explainer.shap_values(
|
| 24 |
X_transformed, check_additivity=False
|
| 25 |
)
|
| 26 |
|
| 27 |
-
#
|
| 28 |
if isinstance(shap_values, list):
|
| 29 |
shap_values = shap_values[1]
|
| 30 |
|
|
@@ -37,12 +51,53 @@ class ExplainabilityEngine:
|
|
| 37 |
f"Explainability not supported for model type: {type(model)}"
|
| 38 |
)
|
| 39 |
|
|
|
|
| 40 |
if shap_values is None or len(shap_values) == 0:
|
| 41 |
raise ValueError("SHAP computation failed")
|
| 42 |
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
return
|
|
|
|
import collections

import numpy as np
import shap
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression


class ExplainabilityEngine:
    """Computes SHAP-based global feature importance for fitted sklearn pipelines."""

    def explain_tabular(self, model_pipeline, X_sample):
        """
        Compute global feature importance for a tabular model using SHAP.

        Aggregates SHAP values from the transformed (encoded) feature space
        back to the original feature names and returns the top 10.

        Parameters
        ----------
        model_pipeline : sklearn.pipeline.Pipeline
            Fitted pipeline with named steps ``"preprocessor"`` and ``"model"``.
        X_sample : pandas.DataFrame
            Raw (untransformed) sample rows; must be non-empty.
            # assumes a DataFrame — the ``.empty`` check below requires it.

        Returns
        -------
        dict
            Mapping of original feature name -> summed mean(|SHAP|) importance,
            sorted descending and truncated to the 10 largest entries.

        Raises
        ------
        ValueError
            On empty input, a malformed pipeline, an unsupported model type,
            or a failed/inconsistent SHAP computation.
        """
        # -------------------- Validation --------------------
        if X_sample is None or X_sample.empty:
            raise ValueError("Sample data is empty, cannot compute explanations")

        if "preprocessor" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'preprocessor' step")

        if "model" not in model_pipeline.named_steps:
            raise ValueError("Pipeline missing 'model' step")

        preprocessor = model_pipeline.named_steps["preprocessor"]
        model = model_pipeline.named_steps["model"]

        # -------------------- Transform Data --------------------
        X_transformed = preprocessor.transform(X_sample)

        if X_transformed is None or X_transformed.shape[0] == 0:
            raise ValueError("Transformed data is empty after preprocessing")

        # -------------------- SHAP Explainer Selection --------------------
        if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
            # check_additivity=False avoids spurious failures on sparse /
            # preprocessed inputs where the additivity check is too strict.
            explainer = shap.TreeExplainer(model)
            shap_values = explainer.shap_values(
                X_transformed, check_additivity=False
            )

            # For binary classification, older SHAP returns a per-class list;
            # take the positive class.
            # NOTE(review): newer shap versions return a 3-D array instead of a
            # list here — that case falls through to the length check below.
            if isinstance(shap_values, list):
                shap_values = shap_values[1]
        elif isinstance(model, (LinearRegression, LogisticRegression)):
            # NOTE(review): this branch was collapsed in the diff this file was
            # recovered from; reconstructed from the module imports and the
            # shared error message below — confirm against the original source.
            explainer = shap.LinearExplainer(model, X_transformed)
            shap_values = explainer.shap_values(X_transformed)
        else:
            raise ValueError(
                f"Explainability not supported for model type: {type(model)}"
            )

        # -------------------- Validate SHAP Output --------------------
        if shap_values is None or len(shap_values) == 0:
            raise ValueError("SHAP computation failed")

        if isinstance(shap_values, list):
            shap_values = np.array(shap_values)

        # -------------------- Aggregate Importance --------------------
        # Mean absolute SHAP value per encoded feature (global importance).
        abs_shap = np.abs(shap_values)
        mean_shap = abs_shap.mean(axis=0)

        try:
            feature_names = preprocessor.get_feature_names_out()
        except Exception as e:
            # Chain the cause so the underlying preprocessor error is preserved.
            raise ValueError(
                f"Failed to retrieve feature names from preprocessor: {e}"
            ) from e

        if len(feature_names) != len(mean_shap):
            raise ValueError(
                "Mismatch between SHAP values and feature names"
            )

        # Aggregate encoded features back to original feature names.
        # Transformer prefixes look like:
        #   num__WindSpeed
        #   cat__PaymentMethod_CreditCard
        aggregated_importance = collections.defaultdict(float)

        for feature_name, importance in zip(feature_names, mean_shap):
            if "__" in feature_name:
                # NOTE(review): split("_")[0] truncates original feature names
                # that themselves contain underscores (e.g. "Payment_Method"
                # -> "Payment") — verify the project's column naming convention.
                original_feature = feature_name.split("__")[1].split("_")[0]
            else:
                original_feature = feature_name

            aggregated_importance[original_feature] += float(importance)

        if not aggregated_importance:
            raise ValueError("No feature importance computed after aggregation")

        # -------------------- Sort + Limit Output --------------------
        sorted_importance = dict(
            sorted(
                aggregated_importance.items(),
                key=lambda x: x[1],
                reverse=True,
            )[:10]
        )

        return sorted_importance