import logging
from datetime import datetime

import numpy as np
import pandas as pd
import shap

from models.artifacts_loader import get_artifacts
from schemas.explain import ExplanationResponse, SHAPFeature, FeatureImportanceResponse
from utils.plot_utils import (
    generate_shap_summary_plot_base64,
    plot_feature_importance_heatmap,
    plotly_to_base64,
)

logger = logging.getLogger(__name__)


def _select_class_shap_values(shap_values_all, class_index: int) -> np.ndarray:
    """Pick the SHAP values belonging to one output class.

    Two output layouts are handled: a list with one array per class, and a
    single 3-D array whose last axis indexes the class. In both layouts an
    out-of-range ``class_index`` falls back to class 0. Any other input is
    returned unchanged (converted to an ndarray).
    """
    if isinstance(shap_values_all, list):
        idx = class_index if class_index < len(shap_values_all) else 0
        return np.asarray(shap_values_all[idx])

    values = np.asarray(shap_values_all)
    if values.ndim == 3:
        idx = class_index if class_index < values.shape[2] else 0
        return values[:, :, idx]
    return values


class ExplainerService:
    """Stateless helpers producing model explanations from loaded artifacts."""

    @staticmethod
    def explain_shap(
        item: dict, top_n: int = 10, plot: bool = False
    ) -> ExplanationResponse:
        """Explain the prediction for a single item with SHAP KernelExplainer.

        Args:
            item: One raw input record; it is converted to a one-row frame by
                the model's own ``_to_dataframe`` helper.
            top_n: Maximum number of features to report, ranked by absolute
                SHAP value. Values ``<= 0`` yield an empty list; values larger
                than the number of features are capped.
            plot: When true, also render a summary plot as base64.

        Returns:
            An ``ExplanationResponse`` with the top features, the optional
            plot and a metadata dict (method name and timestamp).

        Raises:
            RuntimeError: If no model is loaded, or if any step of the
                explanation fails (the original error is chained as the cause).
        """
        artifacts = get_artifacts()
        if not artifacts.model:
            raise RuntimeError("Model not loaded")

        try:
            # 1. Prepare input
            model = artifacts.model
            df = model._to_dataframe([item])
            df_norm = model._normalize_boq_schema(df)

            # The background data for SHAP is expected to be transformed data,
            # so the calibrated classifier is explained on transformed input.
            preprocessor = model.preprocessor
            calibrator = model.calibrator
            X_proc = preprocessor.transform(df_norm)

            # Predict to find the target class index. This is best-effort:
            # on failure we still explain class 0, but leave a trace in the log.
            try:
                pred_label = model.predict(df)[0]
                target_class_index = list(calibrator.classes_).index(pred_label)
            except Exception:
                logger.warning(
                    "Could not resolve predicted class; explaining class index 0",
                    exc_info=True,
                )
                pred_label = "Unknown"
                target_class_index = 0

            # 2. Compute SHAP
            # KernelExplainer is used because CalibratedClassifierCV is not
            # readily supported by TreeExplainer.
            if artifacts.shap_background is None:
                raise RuntimeError(
                    "SHAP background data is missing but required for KernelExplainer"
                )

            explainer = shap.KernelExplainer(
                calibrator.predict_proba, artifacts.shap_background
            )
            shap_values_all = explainer.shap_values(X_proc, nsamples=100)

            shap_vals_target = _select_class_shap_values(
                shap_values_all, target_class_index
            )
            # Only one item was submitted, so take the first row when 2-D.
            sv = shap_vals_target[0] if shap_vals_target.ndim == 2 else shap_vals_target

            # 3. Feature names
            try:
                if hasattr(preprocessor, "get_feature_names_out"):
                    feature_names = preprocessor.get_feature_names_out()
                else:
                    feature_names = preprocessor.pipeline.get_feature_names_out()
            except Exception:
                feature_names = [f"fet_{i}" for i in range(len(sv))]

            # 4. Top features
            # Clamp the count: a bare ``[-top_n:]`` slice returns *all*
            # features for top_n == 0 and an arbitrary tail for negatives.
            abs_importance = np.abs(sv)
            count = max(0, min(top_n, len(sv)))
            top_indices = np.argsort(abs_importance)[::-1][:count]

            top_features = [
                SHAPFeature(feature=str(feature_names[i]), value=float(sv[i]))
                for i in top_indices
            ]

            # 5. Plot
            plot_b64 = None
            if plot:
                plot_b64 = generate_shap_summary_plot_base64(
                    shap_vals_target,
                    X_proc,
                    feature_names=feature_names,
                    target_class=pred_label,
                )

            return ExplanationResponse(
                top_features=top_features,
                shap_plot_base64=plot_b64,
                metadata={
                    "method": "SHAP KernelExplainer",
                    # NOTE(review): utcnow() is naive and deprecated since
                    # Python 3.12; kept so the emitted string format is stable.
                    "timestamp": datetime.utcnow().isoformat(),
                },
            )

        except Exception as e:
            logger.exception("Error during SHAP explanation")
            raise RuntimeError(f"Explanation failed: {e}") from e

    @staticmethod
    def explain_feature_importance(
        top_n: int = 30, skip_top: int = 0
    ) -> FeatureImportanceResponse:
        """Render the global feature-importance heatmap for the loaded model.

        Args:
            top_n: Forwarded to ``plot_feature_importance_heatmap``.
            skip_top: Forwarded to ``plot_feature_importance_heatmap``.

        Returns:
            A ``FeatureImportanceResponse`` carrying the base64-encoded plot
            and a metadata dict echoing the request parameters.

        Raises:
            RuntimeError: If no model is loaded, or if plotting/encoding fails
                (the original error is chained as the cause).
        """
        artifacts = get_artifacts()
        if not artifacts.model:
            raise RuntimeError("Model not loaded")

        try:
            # The helper also returns the plotted frame, which is not needed here.
            fig, _ = plot_feature_importance_heatmap(
                model=artifacts.model,
                top_n=top_n,
                skip_top=skip_top,
                title="Global Feature Importance",
            )

            plot_base64 = plotly_to_base64(fig)

            return FeatureImportanceResponse(
                plot_base64=plot_base64,
                metadata={
                    "method": "Averaged Calibrated Feature Importances",
                    "timestamp": datetime.utcnow().isoformat(),
                    "top_n": top_n,
                    "skip_top": skip_top,
                },
            )
        except Exception as e:
            logger.exception("Error during feature importance generation")
            raise RuntimeError(f"Feature importance generation failed: {e}") from e