ACA050 committed on
Commit
654bcd9
·
verified ·
1 Parent(s): 6f5e4ab

Update backend/core/explainability.py

Browse files
Files changed (1) hide show
  1. backend/core/explainability.py +64 -9
backend/core/explainability.py CHANGED
@@ -1,30 +1,44 @@
1
  import shap
2
  import numpy as np
 
3
  from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
4
  from sklearn.linear_model import LinearRegression, LogisticRegression
5
 
6
 
7
  class ExplainabilityEngine:
8
  def explain_tabular(self, model_pipeline, X_sample):
9
- if X_sample.empty:
 
 
 
 
 
 
10
  raise ValueError("Sample data is empty, cannot compute explanations")
11
 
 
 
 
 
 
 
12
  preprocessor = model_pipeline.named_steps["preprocessor"]
13
  model = model_pipeline.named_steps["model"]
14
 
 
15
  X_transformed = preprocessor.transform(X_sample)
16
 
17
- if X_transformed.shape[0] == 0:
18
- raise ValueError("Transformed sample data is empty after preprocessing")
19
 
20
- # -------- Model-aware SHAP selection --------
21
  if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
22
  explainer = shap.TreeExplainer(model)
23
  shap_values = explainer.shap_values(
24
  X_transformed, check_additivity=False
25
  )
26
 
27
- # Classification returns list
28
  if isinstance(shap_values, list):
29
  shap_values = shap_values[1]
30
 
@@ -37,12 +51,53 @@ class ExplainabilityEngine:
37
  f"Explainability not supported for model type: {type(model)}"
38
  )
39
 
 
40
  if shap_values is None or len(shap_values) == 0:
41
  raise ValueError("SHAP computation failed")
42
 
43
- global_importance = np.abs(shap_values).mean(axis=0).tolist()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- if not global_importance:
46
- raise ValueError("No feature importance computed")
 
 
 
 
 
 
47
 
48
- return global_importance
 
1
  import shap
2
  import numpy as np
3
+ import collections
4
  from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
5
  from sklearn.linear_model import LinearRegression, LogisticRegression
6
 
7
 
8
  class ExplainabilityEngine:
9
  def explain_tabular(self, model_pipeline, X_sample):
10
+ """
11
+ Computes global feature importance for tabular models using SHAP.
12
+ Aggregates SHAP values from transformed features back to original features.
13
+ """
14
+
15
+ # -------------------- Validation --------------------
16
+ if X_sample is None or X_sample.empty:
17
  raise ValueError("Sample data is empty, cannot compute explanations")
18
 
19
+ if "preprocessor" not in model_pipeline.named_steps:
20
+ raise ValueError("Pipeline missing 'preprocessor' step")
21
+
22
+ if "model" not in model_pipeline.named_steps:
23
+ raise ValueError("Pipeline missing 'model' step")
24
+
25
  preprocessor = model_pipeline.named_steps["preprocessor"]
26
  model = model_pipeline.named_steps["model"]
27
 
28
+ # -------------------- Transform Data --------------------
29
  X_transformed = preprocessor.transform(X_sample)
30
 
31
+ if X_transformed is None or X_transformed.shape[0] == 0:
32
+ raise ValueError("Transformed data is empty after preprocessing")
33
 
34
+ # -------------------- SHAP Explainer Selection --------------------
35
  if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
36
  explainer = shap.TreeExplainer(model)
37
  shap_values = explainer.shap_values(
38
  X_transformed, check_additivity=False
39
  )
40
 
41
+ # For binary classification, SHAP returns list → take positive class
42
  if isinstance(shap_values, list):
43
  shap_values = shap_values[1]
44
 
 
51
  f"Explainability not supported for model type: {type(model)}"
52
  )
53
 
54
+ # -------------------- Validate SHAP Output --------------------
55
  if shap_values is None or len(shap_values) == 0:
56
  raise ValueError("SHAP computation failed")
57
 
58
+ if isinstance(shap_values, list):
59
+ shap_values = np.array(shap_values)
60
+
61
+ # -------------------- Aggregate Importance --------------------
62
+ abs_shap = np.abs(shap_values)
63
+ mean_shap = abs_shap.mean(axis=0)
64
+
65
+ try:
66
+ feature_names = preprocessor.get_feature_names_out()
67
+ except Exception as e:
68
+ raise ValueError(
69
+ f"Failed to retrieve feature names from preprocessor: {e}"
70
+ )
71
+
72
+ if len(feature_names) != len(mean_shap):
73
+ raise ValueError(
74
+ "Mismatch between SHAP values and feature names"
75
+ )
76
+
77
+ # Aggregate encoded features back to original feature names
78
+ aggregated_importance = collections.defaultdict(float)
79
+
80
+ for feature_name, importance in zip(feature_names, mean_shap):
81
+ # Examples:
82
+ # num__WindSpeed
83
+ # cat__PaymentMethod_CreditCard
84
+ if "__" in feature_name:
85
+ original_feature = feature_name.split("__")[1].split("_")[0]
86
+ else:
87
+ original_feature = feature_name
88
+
89
+ aggregated_importance[original_feature] += float(importance)
90
+
91
+ if not aggregated_importance:
92
+ raise ValueError("No feature importance computed after aggregation")
93
 
94
+ # -------------------- Sort + Limit Output --------------------
95
+ sorted_importance = dict(
96
+ sorted(
97
+ aggregated_importance.items(),
98
+ key=lambda x: x[1],
99
+ reverse=True
100
+ )[:10]
101
+ )
102
 
103
+ return sorted_importance