File size: 3,930 Bytes
6f5e4ab
 
654bcd9
6f5e4ab
 
 
 
 
 
654bcd9
 
 
 
 
 
 
6f5e4ab
 
654bcd9
 
 
 
 
 
6f5e4ab
 
 
654bcd9
6f5e4ab
 
654bcd9
 
6f5e4ab
654bcd9
6f5e4ab
 
 
 
 
 
654bcd9
6f5e4ab
 
 
 
 
 
 
 
 
 
 
 
654bcd9
6f5e4ab
 
 
654bcd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5e4ab
654bcd9
 
 
 
 
 
 
 
6f5e4ab
654bcd9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import shap
import numpy as np
import collections
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from sklearn.linear_model import LinearRegression, LogisticRegression


class ExplainabilityEngine:
    """SHAP-based global explainability for tabular scikit-learn pipelines.

    Expects a Pipeline with a 'preprocessor' step (e.g. a ColumnTransformer)
    followed by a 'model' step, and aggregates per-encoded-feature SHAP
    values back to the original input columns.
    """

    @staticmethod
    def _collapse_feature_name(feature_name):
        """Map a transformed feature name back to its original column name.

        Examples (ColumnTransformer naming convention):
            num__WindSpeed                -> WindSpeed
            cat__PaymentMethod_CreditCard -> PaymentMethod
            Unprefixed                    -> returned unchanged
        """
        prefix, sep, rest = feature_name.partition("__")
        if not sep:
            # No transformer prefix: already an original column name.
            return feature_name
        if prefix.startswith("cat"):
            # One-hot encoded feature: drop the category suffix.
            # NOTE(review): assumes the original categorical column name
            # contains no underscore; "pay_method_X" would truncate to "pay".
            return rest.split("_")[0]
        # Numeric / passthrough features keep the full remainder so original
        # columns containing underscores (e.g. "wind_speed") are not
        # truncated (the previous code split every name on "_").
        return rest

    def explain_tabular(self, model_pipeline, X_sample):
        """
        Computes global feature importance for tabular models using SHAP.

        Parameters
        ----------
        model_pipeline : sklearn Pipeline with 'preprocessor' and 'model' steps.
        X_sample : pandas DataFrame of raw (untransformed) sample rows.

        Returns
        -------
        dict mapping original feature name -> summed mean(|SHAP|) importance,
        sorted descending and truncated to the top 10 features.

        Raises
        ------
        ValueError
            On empty input, a malformed pipeline, an unsupported model type,
            or a SHAP/feature-name length mismatch.
        """

        # -------------------- Validation --------------------
        if X_sample is None or X_sample.empty:
            raise ValueError("Sample data is empty, cannot compute explanations")

        for step in ("preprocessor", "model"):
            if step not in model_pipeline.named_steps:
                raise ValueError(f"Pipeline missing '{step}' step")

        preprocessor = model_pipeline.named_steps["preprocessor"]
        model = model_pipeline.named_steps["model"]

        # -------------------- Transform Data --------------------
        X_transformed = preprocessor.transform(X_sample)

        if X_transformed is None or X_transformed.shape[0] == 0:
            raise ValueError("Transformed data is empty after preprocessing")

        # -------------------- SHAP Explainer Selection --------------------
        if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
            explainer = shap.TreeExplainer(model)
            # check_additivity=False tolerates small numeric drift between
            # the model output and the sum of SHAP values.
            shap_values = explainer.shap_values(
                X_transformed, check_additivity=False
            )
        elif isinstance(model, (LinearRegression, LogisticRegression)):
            explainer = shap.LinearExplainer(model, X_transformed)
            shap_values = explainer.shap_values(X_transformed)
        else:
            raise ValueError(
                f"Explainability not supported for model type: {type(model)}"
            )

        # -------------------- Validate / Normalize SHAP Output ------------
        if shap_values is None or len(shap_values) == 0:
            raise ValueError("SHAP computation failed")

        # Older SHAP returns a list with one array per class for classifiers
        # (from *both* tree and linear explainers); take the positive class.
        # Previously this was only handled in the tree branch, so a binary
        # LogisticRegression produced a (2, n, f) array and a guaranteed
        # length-mismatch error downstream.
        if isinstance(shap_values, list):
            shap_values = shap_values[1] if len(shap_values) > 1 else shap_values[0]

        shap_values = np.asarray(shap_values)

        # Newer SHAP (>= 0.45) returns (n_samples, n_features, n_classes)
        # for classifiers; select the last class (the positive class for
        # binary problems) to keep a 2-D (n_samples, n_features) array.
        if shap_values.ndim == 3:
            shap_values = shap_values[..., -1]

        # -------------------- Aggregate Importance --------------------
        # Mean absolute SHAP value per encoded feature across the sample.
        mean_abs_shap = np.abs(shap_values).mean(axis=0)

        try:
            feature_names = preprocessor.get_feature_names_out()
        except Exception as e:
            # Chain the cause so the underlying sklearn error is preserved.
            raise ValueError(
                f"Failed to retrieve feature names from preprocessor: {e}"
            ) from e

        if len(feature_names) != len(mean_abs_shap):
            raise ValueError(
                "Mismatch between SHAP values and feature names"
            )

        # Collapse encoded feature names back to original columns, summing
        # the importance of all one-hot columns derived from the same input.
        aggregated_importance = collections.defaultdict(float)
        for feature_name, importance in zip(feature_names, mean_abs_shap):
            original_feature = self._collapse_feature_name(feature_name)
            aggregated_importance[original_feature] += float(importance)

        if not aggregated_importance:
            raise ValueError("No feature importance computed after aggregation")

        # -------------------- Sort + Limit Output --------------------
        top_features = sorted(
            aggregated_importance.items(),
            key=lambda kv: kv[1],
            reverse=True,
        )[:10]

        return dict(top_features)