ACA050 committed on
Commit
6f5e4ab
·
verified ·
1 Parent(s): 044c9a2

Update backend/core/explainability.py

Browse files
Files changed (1) hide show
  1. backend/core/explainability.py +48 -33
backend/core/explainability.py CHANGED
@@ -1,33 +1,48 @@
1
- import shap
2
- import numpy as np
3
-
4
- class ExplainabilityEngine:
5
- def explain_tabular(self, model_pipeline, X_sample):
6
- if X_sample.empty:
7
- raise ValueError("Sample data is empty, cannot compute explanations")
8
-
9
- # Extract trained model and preprocessor
10
- preprocessor = model_pipeline.named_steps["preprocessor"]
11
- model = model_pipeline.named_steps["model"]
12
-
13
- X_transformed = preprocessor.transform(X_sample)
14
-
15
- if X_transformed.shape[0] == 0:
16
- raise ValueError("Transformed sample data is empty after preprocessing")
17
-
18
- explainer = shap.Explainer(model, X_transformed)
19
- shap_values = explainer(X_transformed, check_additivity=False)
20
-
21
- if shap_values is None or shap_values.values is None:
22
- raise ValueError("SHAP computation failed")
23
-
24
- global_importance = np.abs(shap_values.values).mean(axis=0).tolist()
25
-
26
- if len(global_importance) == 0:
27
- raise ValueError("No feature importance computed")
28
-
29
- return global_importance
30
-
31
-
32
-
33
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import shap
2
+ import numpy as np
3
+ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
4
+ from sklearn.linear_model import LinearRegression, LogisticRegression
5
+
6
+
7
class ExplainabilityEngine:
    """Computes global feature-importance explanations for trained tabular models."""

    def explain_tabular(self, model_pipeline, X_sample):
        """Return global per-feature importance scores for a fitted pipeline.

        Args:
            model_pipeline: A fitted sklearn Pipeline with named steps
                "preprocessor" and "model".
            X_sample: A pandas DataFrame of raw (untransformed) sample rows
                used as the SHAP background/explanation data.

        Returns:
            list[float]: One non-negative importance score per transformed
            feature (mean absolute SHAP value across samples, and across
            classes for classifiers).

        Raises:
            ValueError: If the sample is empty, the model type is
                unsupported, or SHAP produces no values.
        """
        if X_sample.empty:
            raise ValueError("Sample data is empty, cannot compute explanations")

        # Extract the fitted preprocessing step and the estimator itself.
        preprocessor = model_pipeline.named_steps["preprocessor"]
        model = model_pipeline.named_steps["model"]

        X_transformed = preprocessor.transform(X_sample)

        if X_transformed.shape[0] == 0:
            raise ValueError("Transformed sample data is empty after preprocessing")

        # -------- Model-aware SHAP selection --------
        if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
            explainer = shap.TreeExplainer(model)
            shap_values = explainer.shap_values(
                X_transformed, check_additivity=False
            )

            # Older SHAP versions return a list of per-class arrays for
            # classifiers. Stack them into one array so that EVERY class
            # contributes to the importance score. (Keeping only index 1,
            # as before, silently ignored classes in multiclass problems
            # and is equivalent for binary, where the two arrays are
            # negatives of each other.)
            if isinstance(shap_values, list):
                shap_values = np.stack(shap_values, axis=-1)

        elif isinstance(model, (LinearRegression, LogisticRegression)):
            explainer = shap.LinearExplainer(model, X_transformed)
            shap_values = explainer.shap_values(X_transformed)

        else:
            raise ValueError(
                f"Explainability not supported for model type: {type(model)}"
            )

        if shap_values is None or len(shap_values) == 0:
            raise ValueError("SHAP computation failed")

        abs_values = np.abs(np.asarray(shap_values))
        # Newer SHAP releases return (n_samples, n_features, n_classes) for
        # classifiers; collapse any trailing class axes so the result is
        # always exactly one scalar per feature.
        while abs_values.ndim > 2:
            abs_values = abs_values.mean(axis=-1)

        global_importance = abs_values.mean(axis=0).tolist()

        if not global_importance:
            raise ValueError("No feature importance computed")

        return global_importance