# D0nG4667
# Fix: Extract exact feature names in DL explainer and resolve Kaleido Docker dependency
# 94e633e
import logging
from datetime import datetime, timezone

import numpy as np
import shap

from core.exceptions import ModelPredictionError
from models.artifacts_loader import get_artifacts
from schemas.patient import Patient, ExplainResponse
from utils.preprocessing import preprocess_input
from utils.shap_utils import generate_shap_summary_plot_base64
logger = logging.getLogger(__name__)
class ExplainerService:
    """Generates SHAP-based explanations for individual patient predictions.

    The underlying ``shap.DeepExplainer`` is expensive to build, so it is
    constructed lazily on first use and cached for the lifetime of the
    (module-level singleton) service instance.
    """

    def __init__(self):
        # Shared model/preprocessor artifacts loaded elsewhere at startup.
        self.artifacts = get_artifacts()
        # Cached SHAP explainer; populated lazily by _get_explainer().
        self.explainer = None

    def _get_explainer(self):
        """
        Lazily build and cache a ``shap.DeepExplainer``.

        Returns:
            The cached explainer, or ``None`` when the model/background data
            is unavailable or DeepExplainer initialization fails.
        """
        if self.explainer is not None:
            return self.explainer
        if not self.artifacts.model or self.artifacts.background_data is None:
            return None
        try:
            # DeepExplainer requires the model and background data as a numpy
            # array; convert a DataFrame (or similar) via .values if needed.
            bg_data = self.artifacts.background_data
            if hasattr(bg_data, "values"):
                bg_data = bg_data.values
            # NOTE: passing the full background set as-is; a summary (e.g.
            # shap.sample(bg_data, 100) or kmeans) would speed this up.
            self.explainer = shap.DeepExplainer(self.artifacts.model, bg_data)
            return self.explainer
        except Exception as e:
            logger.warning(f"DeepExplainer failed initialization: {e}. Falling back to KernelExplainer (not implemented for DL speed reasons usually).")
            # Fallback could go here
            return None

    def _resolve_feature_names(self, patient: Patient, n_values: int) -> list:
        """
        Resolve human-readable feature names matching the SHAP vector length.

        Tries, in order: the preprocessor's ``get_feature_names_out`` (to match
        the ML system), the feature_creator's, then the patient payload keys.
        Any candidate list whose length differs from ``n_values`` is replaced
        with generic "Feature N" placeholders, which prevents an IndexError
        when artifact-provided names do not line up with the SHAP output.

        Args:
            patient: Validated patient payload (used as a naming fallback).
            n_values: Length of the flattened SHAP value vector.

        Returns:
            A list of exactly ``n_values`` feature-name strings.
        """
        feature_names = None
        try:
            # 1st try: preprocessor (Scaling/Encoding)
            if self.artifacts.preprocessor and hasattr(self.artifacts.preprocessor, "get_feature_names_out"):
                feature_names = self.artifacts.preprocessor.get_feature_names_out()
            # 2nd try: feature_creator (Pipeline)
            elif self.artifacts.feature_creator and hasattr(self.artifacts.feature_creator, "get_feature_names_out"):
                feature_names = self.artifacts.feature_creator.get_feature_names_out()
        except Exception as e:
            logger.warning(f"Could not get feature names from artifacts: {e}")
        if feature_names is None:
            # Fallback to patient object keys
            feature_list = list(patient.model_dump(exclude={'include_shap'}).keys())
        else:
            feature_list = list(feature_names)
        # Uniform mismatch guard (previously only applied to the patient-keys
        # fallback, so mismatched artifact names could raise IndexError later).
        if len(feature_list) != n_values:
            feature_list = [f"Feature {i}" for i in range(n_values)]
        return feature_list

    async def explain(self, patient: Patient, top_n: int = 10, plot: bool = False) -> ExplainResponse:
        """
        Compute SHAP attributions for a single patient.

        Args:
            patient: Validated patient payload.
            top_n: Number of highest-|SHAP| features to include in the result.
            plot: When True, also render a base64-encoded summary plot.

        Returns:
            ExplainResponse with the ranked top features, optional plot, and
            method/timestamp metadata.

        Raises:
            ModelPredictionError: If artifacts are not loaded, the explainer
                cannot be initialized, or SHAP computation fails.
        """
        if not self.artifacts.is_loaded:
            raise ModelPredictionError("Model artifacts not loaded")
        features = preprocess_input(
            patient,
            feature_creator=self.artifacts.feature_creator,
            preprocessor=self.artifacts.preprocessor
        )
        explainer = self._get_explainer()
        if not explainer:
            raise ModelPredictionError("SHAP Explainer could not be initialized")
        try:
            shap_values = explainer.shap_values(features)
            # shap_values is a list for multi-output models, or a bare array.
            # Only the first (single) output node is handled here.
            sv = shap_values[0] if isinstance(shap_values, list) else shap_values
            # sv is (1, n_features) — flatten to a 1-D attribution vector.
            sv_flat = sv.flatten()
            # Feature names extraction (to match the ML system).
            feature_list = self._resolve_feature_names(patient, len(sv_flat))
            # Rank by absolute attribution; keep the user-provided top_n.
            indices = np.argsort(-np.abs(sv_flat))[:top_n]
            top_features = [
                {"feature": str(feature_list[i]), "value": float(sv_flat[i])}
                for i in indices
            ]
            plot_base64 = None
            if plot:
                # Plotly utility expects (shap_values, X, feature_names) where
                # shap_values is the 1-D attribution vector.
                plot_base64 = generate_shap_summary_plot_base64(sv_flat, features, feature_list)
            return ExplainResponse(
                top_features=top_features,
                shap_plot_base64=plot_base64,
                metadata={
                    "method": "SHAP",
                    # Timezone-aware UTC timestamp; datetime.utcnow() is
                    # deprecated since Python 3.12.
                    "timestamp": datetime.now(timezone.utc).isoformat()
                }
            )
        except Exception as e:
            logger.exception("Explanation generation failed")
            # Chain the original cause so callers/logs see the root failure.
            raise ModelPredictionError(f"Explanation failed: {e}") from e
# Module-level singleton: all requests share one service (and its cached
# SHAP explainer) instead of rebuilding it per call.
explainer_service = ExplainerService()


def get_explainer_service() -> ExplainerService:
    """Dependency provider returning the shared ExplainerService instance."""
    return explainer_service