from typing import List, Dict, Any
import numpy as np
import shap
def _pick_explainer(model, X_background: np.ndarray):
    """
    Choose an appropriate SHAP explainer for *model*.

    Heuristics, checked in order:
      - TreeExplainer for tree-based models (xgboost, lightgbm, catboost,
        sklearn forests / gradient boosting) — fast, exact for trees
      - LinearExplainer for linear models (name contains "linear" or the
        model exposes ``coef_``)
      - KernelExplainer fallback — model-agnostic but slow

    Parameters
    ----------
    model : Any
        Fitted estimator; only duck-typed attributes and the class name
        are inspected.
    X_background : np.ndarray
        Background dataset used by the linear/kernel explainers.

    Returns
    -------
    A SHAP explainer instance wrapping *model*.
    """
    cls_name = type(model).__name__.lower()
    # Detection is purely name/attribute based, so it must not depend on
    # whether xgboost is importable in this process.  (Previously these
    # checks were gated behind `import xgboost`, which silently disabled
    # tree detection — and fell through to the slow KernelExplainer —
    # whenever xgboost was missing, even for models exposing get_booster.)
    tree_markers = (
        "xgb",
        "randomforest",
        "gradientboost",
        "gbm",
        "lightgbm",
        "catboost",
    )
    is_tree = hasattr(model, "get_booster") or any(
        marker in cls_name for marker in tree_markers
    )
    if is_tree:
        return shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")

    if "linear" in cls_name or hasattr(model, "coef_"):
        return shap.LinearExplainer(model, X_background)

    # Fallback for anything else: model-agnostic, but O(samples * features).
    return shap.KernelExplainer(model.predict, X_background)
def explain_instance(
    model,
    x: np.ndarray,
    feature_names: List[str],
    background_X: np.ndarray,
    top_k: int = 8,
) -> Dict[str, Any]:
    """
    Compute SHAP values for a single instance ``x`` (shape: (n_features,)).

    Always reduces the SHAP output to a 1-D vector of length n_features,
    averaging across classes for multiclass models.  Handles both SHAP
    return conventions:
      - legacy: a list of (n_samples, n_features) arrays, one per class
      - modern (shap >= 0.45): a single (n_samples, n_features, n_classes)
        ndarray for multiclass models

    Parameters
    ----------
    model : Any
        Fitted estimator, routed to a suitable explainer by _pick_explainer.
    x : np.ndarray
        Single instance, shape (n_features,).
    feature_names : List[str]
        Names for each feature, length n_features.
    background_X : np.ndarray
        Background dataset for the explainer.
    top_k : int, default 8
        Number of highest-|impact| features to report.

    Returns
    -------
    Dict with keys "shap_values" (list of floats), "base_value" (float),
    and "topk" (list of per-feature dicts sorted by absolute impact).
    """
    x = x.reshape(1, -1)
    explainer = _pick_explainer(model, background_X)
    values = explainer.shap_values(x)

    if isinstance(values, list):
        # Legacy multiclass: list of (n_samples, n_features) arrays.
        # Stack to (n_classes, n_samples, n_features), then average classes.
        values_arr = np.mean(np.stack(values, axis=0), axis=0)
    else:
        values_arr = np.asarray(values)
        if values_arr.ndim == 3:
            # Modern multiclass: (n_samples, n_features, n_classes).
            # BUG FIX: previously this fell through unhandled, was
            # flattened feature-major, and then truncated to n_features,
            # silently mixing per-class values across features.  Average
            # over the trailing class axis instead.
            values_arr = values_arr.mean(axis=-1)

    # First (only) sample, flattened to a 1-D vector.
    shap_vec = np.asarray(values_arr[0]).reshape(-1)

    # Safety net: never report more entries than we have names for.
    n_features = len(feature_names)
    if len(shap_vec) != n_features:
        shap_vec = shap_vec[:n_features]

    # expected_value may be a scalar, list, or array depending on the
    # explainer / number of classes; collapse to one float by averaging.
    base_value = explainer.expected_value
    if isinstance(base_value, (list, np.ndarray)):
        base_value = float(np.mean(base_value))

    # Top-k features by absolute impact.
    abs_imp = np.abs(shap_vec)
    order = np.argsort(-abs_imp)[:top_k].ravel()
    top = []
    for i in map(int, order):
        if i >= n_features:  # safety check (post-truncation indices)
            continue
        top.append({
            "feature": feature_names[i],
            "value": float(x[0, i]),
            # np.mean on a scalar is the identity; it keeps this robust
            # should an object-dtype element slip through.
            "shap": float(np.mean(shap_vec[i])),
            "abs_impact": float(np.mean(abs_imp[i])),
        })

    return {
        "shap_values": shap_vec.tolist(),
        "base_value": float(base_value),
        "topk": top,
    }