import numpy as np
import shap

# -------------------------------
# GLOBAL EXPLAINER CACHE
# -------------------------------
# Building a TreeExplainer is comparatively expensive, so one is cached at
# module level. The model it was built for is remembered alongside it so a
# call with a *different* model object rebuilds the explainer instead of
# silently reusing one that belongs to another model.
# NOTE(review): a model mutated in place (same object) keeps the old
# explainer; call sites that retrain in place would need to reset the cache.
_explainer = None
_explainer_model = None


def _get_explainer(model):
    """Return a ``shap.TreeExplainer`` for *model*, reusing the cached one.

    The cached explainer is reused only when *model* is the same object it
    was created for; otherwise a new explainer is built and cached.
    """
    global _explainer, _explainer_model
    if _explainer is None or _explainer_model is not model:
        _explainer = shap.TreeExplainer(model)
        _explainer_model = model
    return _explainer


def _first_row_values(shap_values, class_index=0):
    """Reduce raw ``shap_values`` output to a 1-D per-feature array.

    Accepts the layouts ``TreeExplainer.shap_values`` produces for 2-D input:
    a list with one ``(n_samples, n_features)`` array per output/class, a
    3-D ``(n_samples, n_features, n_outputs)`` array, or a 2-D
    ``(n_samples, n_features)`` array. Only the first sample is kept.
    """
    if isinstance(shap_values, list):
        # Older shap: one array per output/class.
        vals = np.asarray(shap_values[class_index])
    else:
        vals = np.asarray(shap_values)
        if vals.ndim == 3:
            # Newer shap: outputs/classes stacked on the last axis.
            vals = vals[..., class_index]
    if vals.ndim >= 2:
        # Keep only the first sample so feature indices stay valid.
        vals = vals[0]
    return vals.ravel()


# -------------------------------
# MAIN FUNCTION
# -------------------------------
def shap_explain(model, sample, top_k=5, *, class_index=0):
    """Return the features with the largest absolute SHAP contribution.

    Parameters
    ----------
    model :
        Trained tree model accepted by ``shap.TreeExplainer``
        (e.g. an XGBoost model).
    sample :
        Feature values for one sample, shape ``(1, n_features)``; a 1-D
        array of length ``n_features`` is also accepted. If more rows are
        given, only the first row is explained.
    top_k : int or None
        Maximum number of features to return (``None`` returns all;
        values ``<= 0`` return an empty list).
    class_index : int
        Which output/class to explain when the explainer returns one set of
        SHAP values per output. Defaults to 0.

    Returns
    -------
    list of dict
        One dict per feature, ordered by decreasing ``|impact|``, with keys
        ``"feature_index"`` (int), ``"impact"`` (float rounded to 4 places)
        and ``"pushes_toward"`` (``"AI"`` when the contribution is positive,
        otherwise ``"Human"``). NOTE(review): this assumes a positive
        contribution means "more AI-like" for the explained output —
        confirm against the model's label encoding.
    """
    explainer = _get_explainer(model)

    # Give the explainer a single row so its output layout is unambiguous.
    if getattr(sample, "ndim", None) == 1:
        sample = np.asarray(sample).reshape(1, -1)

    shap_vals = _first_row_values(explainer.shap_values(sample), class_index)

    # Rank features by absolute contribution, largest first.
    order = np.argsort(np.abs(shap_vals))[::-1]
    if top_k is not None:
        # Clamp so a negative top_k cannot slice "all but the last k".
        order = order[:max(int(top_k), 0)]

    return [
        {
            "feature_index": int(i),
            # Convert to a Python float *before* rounding; rounding a
            # float32 and then widening leaves artifacts like 0.1234000027.
            "impact": round(float(shap_vals[i]), 4),
            "pushes_toward": "AI" if shap_vals[i] > 0 else "Human",
        }
        for i in order
    ]