D0nG4667
94e633e Fix: Extract exact feature names in DL explainer and resolve Kaleido Docker dependency

import shap
import logging
import numpy as np
from datetime import datetime

from models.artifacts_loader import get_artifacts
from schemas.patient import Patient, ExplainResponse
from utils.preprocessing import preprocess_input
from utils.shap_utils import generate_shap_summary_plot_base64
from core.exceptions import ModelPredictionError

logger = logging.getLogger(__name__)


class ExplainerService:
    def __init__(self):
        self.artifacts = get_artifacts()
        self.explainer = None

    def _get_explainer(self):
        """
        Lazily load the explainer, to keep startup fast and to handle
        initialization failures gracefully.
        """
        if self.explainer:
            return self.explainer
        if not self.artifacts.model or self.artifacts.background_data is None:
            return None
        try:
            # Try DeepExplainer first.
            # Note: DeepExplainer expects the model and the background data as a numpy
            # array, so convert background_data to numpy if it is a DataFrame.
            bg_data = self.artifacts.background_data
            if hasattr(bg_data, "values"):
                bg_data = bg_data.values
            # A summary of the background data (e.g. shap.sample(bg_data, 100) or k-means)
            # is usually passed in to speed things up; for now the full data is used as is
            # (a hypothetical sketch follows this method).
            self.explainer = shap.DeepExplainer(self.artifacts.model, bg_data)
            return self.explainer
        except Exception as e:
            logger.warning(
                f"DeepExplainer failed to initialize: {e}. "
                "A KernelExplainer fallback would be possible but is not implemented, "
                "as it is usually too slow for DL models."
            )
            # Fallback could go here
            return None
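
    # Hypothetical sketch (not part of this commit): the background set could be
    # subsampled before constructing DeepExplainer to reduce memory use and
    # initialization time. The 100-row sample size is an illustrative assumption.
    #
    #   bg_data = self.artifacts.background_data
    #   if hasattr(bg_data, "values"):
    #       bg_data = bg_data.values
    #   bg_data = shap.sample(bg_data, 100, random_state=0)  # random subsample
    #   self.explainer = shap.DeepExplainer(self.artifacts.model, bg_data)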

    async def explain(self, patient: Patient, top_n: int = 10, plot: bool = False) -> ExplainResponse:
        if not self.artifacts.is_loaded:
            raise ModelPredictionError("Model artifacts not loaded")

        features = preprocess_input(
            patient,
            feature_creator=self.artifacts.feature_creator,
            preprocessor=self.artifacts.preprocessor
        )

        explainer = self._get_explainer()
        if not explainer:
            raise ModelPredictionError("SHAP Explainer could not be initialized")

        try:
            shap_values = explainer.shap_values(features)
            # shap_values is a list for multi-output (functional) models, otherwise a
            # single array; only one output node is handled here.
            if isinstance(shap_values, list):
                sv = shap_values[0]
            else:
                sv = shap_values

            # sv has shape (1, n_features); flatten it for ranking.
            sv_flat = sv.flatten()

            # Feature name extraction (to match the ML system)
            feature_names = None
            try:
                # First try the preprocessor (scaling/encoding)
                if self.artifacts.preprocessor and hasattr(self.artifacts.preprocessor, "get_feature_names_out"):
                    feature_names = self.artifacts.preprocessor.get_feature_names_out()
                # Then try the feature_creator (pipeline)
                elif self.artifacts.feature_creator and hasattr(self.artifacts.feature_creator, "get_feature_names_out"):
                    feature_names = self.artifacts.feature_creator.get_feature_names_out()
            except Exception as e:
                logger.warning(f"Could not get feature names from artifacts: {e}")

            if feature_names is None:
                # Fall back to the patient object's field names
                feature_list = list(patient.model_dump(exclude={'include_shap'}).keys())
                # If the count does not match the SHAP vector, use generic "Feature N" labels
                if len(feature_list) != len(sv_flat):
                    feature_list = [f"Feature {i}" for i in range(len(sv_flat))]
            else:
                feature_list = list(feature_names)

            # Rank features by absolute SHAP value and keep the user-provided top_n
            indices = np.argsort(-np.abs(sv_flat))[:top_n]
            top_features = [
                {"feature": str(feature_list[i]), "value": float(sv_flat[i])}
                for i in indices
            ]

            # Optional plot via the new Plotly utility, which expects
            # (shap_values, X, feature_names) with shap_values as a 1D array.
            plot_base64 = None
            if plot:
                plot_base64 = generate_shap_summary_plot_base64(sv_flat, features, feature_list)
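                # Assumption about the utility's internals: it presumably builds a Plotly
                # figure and exports a static PNG (e.g. fig.to_image(format="png")),
                # which requires the kaleido package to be installed in the Docker image.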

            return ExplainResponse(
                top_features=top_features,
                shap_plot_base64=plot_base64,
                metadata={
                    "method": "SHAP",
                    "timestamp": datetime.utcnow().isoformat()
                }
            )
        except Exception as e:
            logger.exception("Explanation generation failed")
            raise ModelPredictionError(f"Explanation failed: {e}")


explainer_service = ExplainerService()


def get_explainer_service() -> ExplainerService:
    return explainer_service
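
Usage sketch (hypothetical, not code from this commit): it assumes the service is exposed through a FastAPI router via the get_explainer_service dependency; the route path, query parameters, and router wiring are illustrative assumptions.

    from fastapi import APIRouter, Depends

    # Assumes Patient, ExplainResponse, ExplainerService and get_explainer_service
    # are importable from the module above.
    router = APIRouter()

    @router.post("/explain", response_model=ExplainResponse)
    async def explain_patient(
        patient: Patient,
        top_n: int = 10,
        plot: bool = False,
        service: ExplainerService = Depends(get_explainer_service),
    ) -> ExplainResponse:
        # Delegates to the shared ExplainerService instance provided by the dependency;
        # failures surface as ModelPredictionError.
        return await service.explain(patient, top_n=top_n, plot=plot)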