Spaces:
Build error
Build error
| from typing import List, Dict, Any | |
| import numpy as np | |
| import shap | |
def _pick_explainer(model, X_background: np.ndarray):
    """
    Choose a SHAP explainer suited to ``model``.

    Detection is a heuristic based only on the model's class name, the
    top-level package its class is defined in, and a couple of attributes
    (``get_booster`` / ``coef_``). No optional library is imported here.

    - ``shap.TreeExplainer`` (``feature_perturbation="tree_path_dependent"``,
      so ``X_background`` is not used) for XGBoost / LightGBM / CatBoost
      models (including raw booster objects) and scikit-learn tree models:
      random forests, (hist) gradient boosting, decision trees, extra trees.
    - ``shap.LinearExplainer`` for models whose class name contains
      "linear" or that expose a ``coef_`` attribute.
    - ``shap.KernelExplainer`` on ``model.predict`` otherwise. This is
      model-agnostic but slow, and its cost grows with ``len(X_background)``.

    Args:
        model: Fitted model object to explain.
        X_background: Background/reference data. Only used by the Linear and
            Kernel explainers.

    Returns:
        A SHAP explainer instance.
    """
    cls = type(model)
    name = cls.__name__.lower()
    # Top-level package of the model's class, e.g. "xgboost" for
    # xgboost.core.Booster. This catches raw boosters whose class name alone
    # ("booster") carries no hint.
    root_package = (getattr(cls, "__module__", "") or "").split(".")[0]

    tree_packages = ("xgboost", "lightgbm", "catboost")
    tree_name_hints = (
        "xgb",
        "randomforest",
        "gradientboost",
        "gbm",
        "lightgbm",
        "catboost",
        "decisiontree",
        "extratree",
    )
    is_tree = (
        root_package in tree_packages
        or hasattr(model, "get_booster")
        or any(hint in name for hint in tree_name_hints)
    )
    if is_tree:
        return shap.TreeExplainer(model, feature_perturbation="tree_path_dependent")

    is_linear = "linear" in name or hasattr(model, "coef_")
    if is_linear:
        return shap.LinearExplainer(model, X_background)

    # Fallback for anything else (model-agnostic, slow).
    return shap.KernelExplainer(model.predict, X_background)
def _collapse_outputs(arr: np.ndarray, axis: int, class_index) -> np.ndarray:
    """Drop the multi-output ``axis`` of ``arr``: select ``class_index``, or average when it is None."""
    if class_index is None:
        return np.mean(arr, axis=axis)
    return np.take(arr, class_index, axis=axis)


def explain_instance(
    model,
    x: np.ndarray,
    feature_names: List[str],
    background_X: np.ndarray,
    top_k: int = 8,
    *,
    explainer: Any = None,
    class_index: "int | None" = None,
) -> Dict[str, Any]:
    """
    Compute SHAP values for a single instance and summarize the top features.

    The SHAP output is always reduced to one value per feature. Both
    multi-output layouts SHAP may return are handled:

    - a list with one ``(n_samples, n_features)`` array per output (older SHAP);
    - one ``(n_samples, n_features, n_outputs)`` array (newer SHAP).

    Args:
        model: Fitted model. Only used to build an explainer via
            ``_pick_explainer`` when ``explainer`` is not given.
        x: One instance; any array-like that reshapes to ``(1, n_features)``.
        feature_names: Names aligned with the columns of ``x``. SHAP values
            beyond ``len(feature_names)`` are dropped.
        background_X: Background data forwarded to ``_pick_explainer``.
        top_k: Number of features reported under ``"topk"``. Values <= 0 give
            an empty list.
        explainer: Optional pre-built explainer exposing ``shap_values(X)``
            and ``expected_value``. Pass one to avoid rebuilding the explainer
            on every call.
        class_index: For multi-output models, the output to explain. When
            None (default), SHAP values and the base value are averaged across
            outputs. NOTE: for classifiers whose per-class outputs are
            probabilities summing to 1, the signed average can cancel to ~0
            for every feature; pass a class index in that case.

    Returns:
        A dict with these keys:

        - ``"shap_values"``: list of floats, one per feature;
        - ``"base_value"``: float;
        - ``"topk"``: list of ``{"feature", "value", "shap", "abs_impact"}``
          dicts, sorted by descending absolute SHAP value (ties by column
          order).
    """
    x = np.asarray(x).reshape(1, -1)
    if explainer is None:
        explainer = _pick_explainer(model, background_X)
    values = explainer.shap_values(x)

    # Normalize SHAP's various return layouts to (n_samples, n_features).
    if isinstance(values, list):
        # Older SHAP multi-output format: stacking puts outputs on axis 0.
        per_sample = _collapse_outputs(np.stack(values, axis=0), 0, class_index)
    else:
        per_sample = np.asarray(values)
        if per_sample.ndim == 3:
            # Newer SHAP multi-output format: outputs on the last axis.
            # Flattening this directly would interleave features and classes.
            per_sample = _collapse_outputs(per_sample, -1, class_index)
        else:
            per_sample = np.atleast_2d(per_sample)
    shap_vec = np.asarray(per_sample[0], dtype=float).reshape(-1)

    # Best-effort alignment with the supplied names (no-op when already aligned).
    n_features = len(feature_names)
    shap_vec = shap_vec[:n_features]

    base = np.asarray(explainer.expected_value, dtype=float).reshape(-1)
    if class_index is not None and base.size > 1:
        base_value = float(base[class_index])
    else:
        base_value = float(np.mean(base))

    # Top-k by absolute impact; the stable sort keeps ties in column order.
    abs_imp = np.abs(shap_vec)
    k = max(int(top_k), 0)
    order = np.argsort(-abs_imp, kind="stable")[:k]
    top = [
        {
            "feature": feature_names[i],
            "value": float(x[0, i]),
            "shap": float(shap_vec[i]),
            "abs_impact": float(abs_imp[i]),
        }
        for i in map(int, order)
    ]

    return {
        "shap_values": shap_vec.tolist(),
        "base_value": base_value,
        "topk": top,
    }