import logging
from datetime import datetime, timezone

import numpy as np
import pandas as pd
import shap

from models.artifacts_loader import get_artifacts
from schemas.explain import ExplanationResponse, SHAPFeature, FeatureImportanceResponse
from utils.plot_utils import (
    generate_shap_summary_plot_base64,
    plot_feature_importance_heatmap,
    plotly_to_base64,
)


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class ExplainerService:
    """Stateless entry points for explaining the loaded model.

    Every method is a ``@staticmethod`` that fetches the current artifacts
    through ``get_artifacts()`` on each call, so the class keeps no state.
    """

    @staticmethod
    def explain_shap(
        item: dict,
        top_n: int = 10,
        plot: bool = False,
        *,
        nsamples: int = 100,
    ) -> ExplanationResponse:
        """Explain the prediction for one item with SHAP's KernelExplainer.

        The item is converted and preprocessed by the model's own helpers,
        SHAP values are computed against ``calibrator.predict_proba`` using
        ``artifacts.shap_background``, and the values of the predicted class
        are ranked by absolute magnitude.

        Args:
            item: Raw record to explain. It is passed to the model's
                ``_to_dataframe`` helper, so the accepted keys are whatever
                that helper expects.
            top_n: Maximum number of features to return, strongest first.
                Zero or a negative value yields an empty list.
            plot: When true, also render a SHAP summary plot and return it in
                ``shap_plot_base64``.
            nsamples: Sampling budget forwarded to ``explainer.shap_values``.

        Returns:
            An ``ExplanationResponse`` holding the top features, the optional
            plot and a small metadata dict (method name and timestamp).

        Raises:
            RuntimeError: If no model is loaded, or (wrapping the original
                exception) if any step of the explanation fails, including
                missing SHAP background data.
        """
        artifacts = get_artifacts()
        if not artifacts.model:
            raise RuntimeError("Model not loaded")

        try:
            # Fail fast: KernelExplainer cannot be built without background
            # data, so check before paying for preprocessing and prediction.
            if artifacts.shap_background is None:
                raise RuntimeError(
                    "SHAP background data is missing but required for KernelExplainer"
                )

            model = artifacts.model
            df = model._to_dataframe([item])
            df_norm = model._normalize_boq_schema(df)

            preprocessor = model.preprocessor
            calibrator = model.calibrator
            X_proc = preprocessor.transform(df_norm)

            # Best effort: if the predicted class cannot be resolved, explain
            # class index 0 instead of failing, but leave a trace in the logs.
            try:
                pred_label = model.predict(df)[0]
                target_class_index = list(calibrator.classes_).index(pred_label)
            except Exception:
                logger.warning(
                    "Could not resolve the predicted class; explaining class index 0",
                    exc_info=True,
                )
                pred_label = "Unknown"
                target_class_index = 0

            explainer = shap.KernelExplainer(
                calibrator.predict_proba, artifacts.shap_background
            )
            shap_values_all = explainer.shap_values(X_proc, nsamples=nsamples)

            shap_vals_target = ExplainerService._select_class_shap_values(
                shap_values_all, target_class_index
            )
            # Only one item was submitted, so a 2-D result is reduced to its
            # first row; the 2-D form is still what the plot helper receives.
            if np.ndim(shap_vals_target) == 2:
                sv = shap_vals_target[0]
            else:
                sv = shap_vals_target

            feature_names = ExplainerService._resolve_feature_names(
                preprocessor, len(sv)
            )
            top_features = [
                SHAPFeature(feature=str(feature_names[i]), value=float(sv[i]))
                for i in ExplainerService._rank_top_indices(sv, top_n)
            ]

            plot_b64 = None
            if plot:
                plot_b64 = generate_shap_summary_plot_base64(
                    shap_vals_target,
                    X_proc,
                    feature_names=feature_names,
                    target_class=pred_label,
                )

            return ExplanationResponse(
                top_features=top_features,
                shap_plot_base64=plot_b64,
                metadata={
                    "method": "SHAP KernelExplainer",
                    "timestamp": ExplainerService._utc_timestamp(),
                },
            )
        except Exception as e:
            logger.exception("Error during SHAP explanation")
            raise RuntimeError(f"Explanation failed: {e}") from e

    @staticmethod
    def _select_class_shap_values(shap_values_all, target_class_index: int):
        """Pick the SHAP values of one class out of the explainer output.

        Two multi-class layouts are handled: a list with one array per class,
        and a single 3-D array whose last axis indexes the class. In both, an
        out-of-range ``target_class_index`` falls back to class 0. Any other
        input (for example a 1-D or 2-D array) is returned unchanged.
        """
        if isinstance(shap_values_all, list):
            if target_class_index < len(shap_values_all):
                return shap_values_all[target_class_index]
            return shap_values_all[0]

        if isinstance(shap_values_all, np.ndarray) and shap_values_all.ndim == 3:
            if target_class_index < shap_values_all.shape[2]:
                return shap_values_all[:, :, target_class_index]
            # Mirror the list branch instead of passing a 3-D array through,
            # which would break the per-feature ranking later on.
            return shap_values_all[:, :, 0]

        return shap_values_all

    @staticmethod
    def _rank_top_indices(sv, top_n: int) -> np.ndarray:
        """Return indices of the ``top_n`` largest ``|sv|`` entries, strongest first."""
        strongest_first = np.argsort(np.abs(sv))[::-1]
        # Slice from the front: the previous ``[-top_n:]`` form returned every
        # feature for ``top_n == 0`` because ``-0`` is just ``0``.
        return strongest_first[: max(top_n, 0)]

    @staticmethod
    def _resolve_feature_names(preprocessor, n_features: int):
        """Return the preprocessor's output feature names, or generic ones.

        Tries ``preprocessor.get_feature_names_out()`` first and
        ``preprocessor.pipeline.get_feature_names_out()`` otherwise. If that
        raises, ``n_features`` placeholder names are generated instead.
        """
        try:
            if hasattr(preprocessor, "get_feature_names_out"):
                return preprocessor.get_feature_names_out()
            return preprocessor.pipeline.get_feature_names_out()
        except Exception:
            logger.warning(
                "Could not obtain feature names from the preprocessor; "
                "using generic names",
                exc_info=True,
            )
            return [f"fet_{i}" for i in range(n_features)]

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string."""
        # ``datetime.utcnow()`` is deprecated since Python 3.12. Dropping the
        # tzinfo keeps the emitted format unchanged (no "+00:00" suffix).
        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    @staticmethod
    def explain_feature_importance(
        top_n: int = 30, skip_top: int = 0
    ) -> FeatureImportanceResponse:
        """Render the model's global feature-importance heatmap as base64.

        Args:
            top_n: Forwarded to ``plot_feature_importance_heatmap``;
                presumably the number of features to show.
            skip_top: Forwarded to ``plot_feature_importance_heatmap``;
                presumably how many leading features to leave out.

        Returns:
            A ``FeatureImportanceResponse`` with the encoded plot and a
            metadata dict (method name, timestamp and both arguments).

        Raises:
            RuntimeError: If no model is loaded, or (wrapping the original
                exception) if plotting or encoding fails.
        """
        artifacts = get_artifacts()
        if not artifacts.model:
            raise RuntimeError("Model not loaded")

        try:
            # The helper returns a pair; only the figure is needed here.
            fig, _ = plot_feature_importance_heatmap(
                model=artifacts.model,
                top_n=top_n,
                skip_top=skip_top,
                title="Global Feature Importance",
            )
            plot_base64 = plotly_to_base64(fig)

            return FeatureImportanceResponse(
                plot_base64=plot_base64,
                metadata={
                    "method": "Averaged Calibrated Feature Importances",
                    "timestamp": ExplainerService._utc_timestamp(),
                    "top_n": top_n,
                    "skip_top": skip_top,
                },
            )
        except Exception as e:
            logger.exception("Error during feature importance generation")
            raise RuntimeError(f"Feature importance generation failed: {e}") from e

