File size: 3,180 Bytes
2ae10e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import List, Dict, Any
import numpy as np
import shap


def _pick_explainer(model, X_background: np.ndarray):
    """
    Build a SHAP explainer suited to ``model``.

    The decision is made from the model's class name and a couple of
    duck-typed attributes, in this order:

    1. ``shap.TreeExplainer`` when the model looks tree-based (XGBoost-like,
       or the class name contains a known tree-ensemble marker).
    2. ``shap.LinearExplainer`` when the class name contains ``"linear"`` or
       the model exposes ``coef_``.
    3. ``shap.KernelExplainer`` over ``model.predict`` for everything else
       (general, but slow).

    ``X_background`` is only handed to the linear and kernel explainers; the
    tree explainer is created without background data.
    """
    model_name = type(model).__name__.lower()

    # The XGBoost heuristics only count when the xgboost package itself can
    # be imported; any failure while probing is treated as "not XGBoost".
    try:
        import xgboost  # noqa: F401
        tree_based = hasattr(model, "get_booster") or "xgb" in model_name
    except Exception:
        tree_based = False

    tree_markers = ("randomforest", "gradientboost", "gbm", "lightgbm", "catboost")
    if tree_based or any(marker in model_name for marker in tree_markers):
        return shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")

    if "linear" in model_name or hasattr(model, "coef_"):
        return shap.LinearExplainer(model, X_background)

    # Nothing more specific matched: fall back to the model-agnostic explainer.
    return shap.KernelExplainer(model.predict, X_background)


def _shap_output_to_vector(values) -> np.ndarray:
    """Reduce raw ``shap_values`` output for one sample to a 1-D per-feature vector."""
    if isinstance(values, list):
        # Older SHAP multi-output format: one (n_samples, n_features) array
        # per class. Stack to (n_classes, n_samples, n_features) and average
        # the class axis away.
        per_sample = np.mean(np.stack(values, axis=0), axis=0)
    else:
        per_sample = np.asarray(values)
        if per_sample.ndim == 3:
            # Newer SHAP multi-output format: a single array shaped
            # (n_samples, n_features, n_outputs). Average over the outputs so
            # the result is per-feature; flattening this directly would
            # interleave features and outputs.
            per_sample = per_sample.mean(axis=-1)

    # Only the first (and, for this module, only) sample is of interest.
    return np.asarray(per_sample[0], dtype=float).reshape(-1)


def explain_instance(
    model,
    x: np.ndarray,
    feature_names: List[str],
    background_X: np.ndarray,
    top_k: int = 8,
) -> Dict[str, Any]:
    """
    Compute SHAP attributions for a single instance.

    Args:
        model: Fitted model; the explainer type is chosen by
            ``_pick_explainer``.
        x: Feature values of the instance, shape ``(n_features,)``. It is
            reshaped to a single-row 2-D array before being explained.
        feature_names: Names of the features, in column order.
        background_X: Background data forwarded to ``_pick_explainer``.
        top_k: How many of the highest-|SHAP| features to report.

    Returns:
        A dict with:

        - ``"shap_values"``: list of per-feature SHAP values. Multi-output
          results (list-of-arrays or a 3-D array) are averaged across
          outputs so there is exactly one value per feature.
        - ``"base_value"``: the explainer's expected value as a float,
          averaged when the explainer reports one per output.
        - ``"topk"``: list of dicts (``feature``, ``value``, ``shap``,
          ``abs_impact``) for the ``top_k`` features ranked by absolute
          SHAP value, largest first.

    NOTE(review): averaging across outputs is kept from the original design.
    For two-class models whose explainer reports both classes the per-class
    contributions may largely cancel; confirm this is the intended summary.
    """
    x = np.asarray(x).reshape(1, -1)
    explainer = _pick_explainer(model, background_X)

    shap_vec = _shap_output_to_vector(explainer.shap_values(x))

    # Best-effort guard kept from the original: if the explainer produced
    # more entries than there are names, keep only the leading ones so every
    # reported index has a name.
    n_features = len(feature_names)
    if len(shap_vec) != n_features:
        shap_vec = shap_vec[:n_features]

    # np.mean is a no-op for a scalar expected value and averages the
    # per-output values when the explainer returns a list/array.
    base_value = float(np.mean(explainer.expected_value))

    # Rank features by absolute impact, largest first.
    abs_imp = np.abs(shap_vec)
    order = np.argsort(-abs_imp)[:top_k].ravel()

    top = [
        {
            "feature": feature_names[i],
            "value": float(x[0, i]),
            "shap": float(shap_vec[i]),
            "abs_impact": float(abs_imp[i]),
        }
        for i in map(int, order)
    ]

    return {
        "shap_values": shap_vec.tolist(),
        "base_value": base_value,
        "topk": top,
    }