import base64
import io
from typing import Optional, Tuple

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
import shap

# Legend labels; they double as keys of the colour map passed to px.strip.
_RISK_LABEL = "Risk (Positive)"
_PROTECTIVE_LABEL = "Protective (Negative)"


def _as_sample_matrix(shap_values) -> np.ndarray:
    """Return *shap_values* as a 2-D float array of shape (samples, features).

    A list whose elements are themselves array-like is treated as a
    per-class list: the second entry is used when there is more than one,
    otherwise the first. A flat list of numbers is taken as the values of a
    single sample.

    Raises:
        ValueError: if the values are neither 1-D nor 2-D after selection.
    """
    vals = shap_values
    if isinstance(vals, list) and vals and np.ndim(vals[0]) >= 1:
        vals = vals[1] if len(vals) > 1 else vals[0]

    matrix = np.asarray(vals, dtype=float)
    if matrix.ndim == 1:
        # A single explanation: one row, one column per feature.
        matrix = matrix.reshape(1, -1)
    if matrix.ndim != 2:
        raise ValueError(
            f"Expected 1-D or 2-D SHAP values, got {matrix.ndim} dimensions."
        )
    return matrix


def generate_shap_summary_plot_base64(shap_values, X_proc, feature_names=None, target_class=None) -> str:
    """
    Generates a SHAP summary plot using Plotly (Strip Plot) and returns it as a base64 string.

    Args:
        shap_values: SHAP values for one sample (1-D), several samples
            (2-D, samples x features), or a per-class list of such arrays.
        X_proc: Processed input; only used to derive feature names when
            ``feature_names`` is None (its ``columns``, else its width).
        feature_names: Optional names, one per SHAP feature.
        target_class: Optional label appended to the title when truthy.

    Returns:
        The PNG image encoded as a base64 string, or ``""`` if anything
        fails (the error is printed; this function never raises).
    """
    try:
        # 1. Prepare Feature Names
        if feature_names is None:
            if hasattr(X_proc, "columns"):
                feature_names = list(X_proc.columns)
            else:
                n_cols = X_proc.shape[1] if hasattr(X_proc, "shape") else len(X_proc[0])
                feature_names = [f"Feature {i}" for i in range(n_cols)]
        else:
            feature_names = list(feature_names)

        # 2. Normalise the SHAP input to (samples, features)
        vals = _as_sample_matrix(shap_values)
        n_samples, n_features = vals.shape
        if n_features != len(feature_names):
            raise ValueError(
                f"{n_features} SHAP values per sample vs {len(feature_names)} feature names."
            )

        # 3. Create a long-form DataFrame: one row per (sample, feature).
        # ravel() is row-major, so tiling the names keeps them aligned.
        df_plot = pd.DataFrame(
            {
                "Feature": np.tile(np.asarray(feature_names, dtype=object), n_samples),
                "SHAP": vals.ravel(),
            }
        )

        # Add coloring based on impact direction (Risk/Protective)
        df_plot["Type"] = np.where(df_plot["SHAP"] > 0, _RISK_LABEL, _PROTECTIVE_LABEL)

        # Rank features by (mean) absolute SHAP value.
        df_plot["AbsSHAP"] = df_plot["SHAP"].abs()
        df_plot = df_plot.sort_values("AbsSHAP", ascending=True)
        ranking = (
            df_plot.groupby("Feature", sort=False)["AbsSHAP"]
            .mean()
            .sort_values(ascending=True, kind="stable")
        )
        # Categorical y axes are laid out bottom-to-top, so ascending order
        # puts the most important feature at the top.
        feature_order = list(ranking.index)

        # 4. Generate Plotly Strip Plot
        title_suffix = f" (Predicted: {target_class})" if target_class else ""
        fig = px.strip(
            df_plot,
            x="SHAP",
            y="Feature",
            color="Type",
            stripmode="overlay",
            color_discrete_map={
                _RISK_LABEL: "#ef4444",
                _PROTECTIVE_LABEL: "#10b981",
            },
            title=f"SHAP Impact Analysis{title_suffix}",
        )

        fig.update_layout(
            xaxis=dict(
                title="SHAP Value (Impact on Model Probability)",
                showgrid=True,
                gridcolor="WhiteSmoke",
                zerolinecolor="Gainsboro",
            ),
            yaxis=dict(
                title="Feature",
                showgrid=True,
                gridcolor="WhiteSmoke",
                zerolinecolor="Gainsboro",
                # Colouring splits the data into one trace per Type, and the
                # default category order follows trace appearance, which
                # would interleave the ranking. Pin it explicitly.
                categoryorder="array",
                categoryarray=feature_order,
            ),
            plot_bgcolor="rgba(0,0,0,0)",
            paper_bgcolor="rgba(0,0,0,0)",
            height=max(500, len(feature_names) * 40),
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1,
            ),
        )

        fig.update_traces(
            jitter=1,
            marker=dict(size=12, opacity=0.9, line=dict(width=1, color="DarkSlateGrey")),
        )

        # 5. Export to Base64 Image
        # Requires 'kaleido' package installed
        img_bytes = fig.to_image(format="png", engine="kaleido", scale=2)
        return base64.b64encode(img_bytes).decode("utf-8")

    except Exception as e:
        # Deliberate best-effort: callers receive an empty string on failure.
        print(f"Error generating Plotly SHAP image: {e}")
        return ""


def _extract_importances(estimator):
    """Return a per-feature importance vector for *estimator*, or None.

    Uses ``feature_importances_`` when present, otherwise the mean absolute
    ``coef_`` across rows. Returns None when the estimator exposes neither.
    """
    if hasattr(estimator, "feature_importances_"):
        return estimator.feature_importances_
    if hasattr(estimator, "coef_"):
        coef = estimator.coef_
        if np.ndim(coef) == 1:
            # A 1-D coef_ would collapse to a scalar under mean(axis=0).
            coef = np.reshape(coef, (1, -1))
        return np.abs(coef).mean(axis=0)
    return None


def get_calibrated_feature_importances(model) -> pd.Series:
    """
    Safely extract and aggregate feature importances from a calibrated production model.

    The importances of every entry in ``model.calibrator.calibrated_classifiers_``
    are averaged. Entries whose underlying estimator exposes neither
    ``feature_importances_`` nor ``coef_`` contribute a vector of zeros.

    Args:
        model: Object exposing ``preprocessor``, ``calibrator`` and ``_logger``.

    Returns:
        A Series of averaged importances indexed by feature name. Generic
        ``Feature_<i>`` names are used when the real names cannot be read or
        their count does not match the importances.

    Raises:
        ValueError: if the calibrator has no ``calibrated_classifiers_``
            attribute, or that collection is empty.
    """
    try:
        if hasattr(model.preprocessor, "get_feature_names_out"):
            feature_names = model.preprocessor.get_feature_names_out()
        else:
            feature_names = model.preprocessor.pipeline.get_feature_names_out()
    except AttributeError:
        model._logger.warning("Could not extract feature names. Using generic names.")
        feature_names = [f"Feature_{i}" for i in range(model.calibrator.n_features_in_)]

    calibrator = model.calibrator
    if not hasattr(calibrator, "calibrated_classifiers_"):
        raise ValueError(
            "Calibrator is missing 'calibrated_classifiers_'. Is it fitted?"
        )

    per_fold = []
    for calibrated_clf in calibrator.calibrated_classifiers_:
        # The wrapped estimator is looked up under either attribute name.
        base_model = getattr(
            calibrated_clf, "estimator", getattr(calibrated_clf, "base_estimator", None)
        )
        per_fold.append(_extract_importances(base_model))

    if not per_fold:
        raise ValueError("Calibrator has no calibrated classifiers to aggregate.")

    # Size the zero placeholders from the real vectors so the stack is never
    # ragged, even when the feature-name count disagrees with the model.
    known = [vec for vec in per_fold if vec is not None]
    width = len(known[0]) if known else len(feature_names)
    filled = [vec if vec is not None else np.zeros(width) for vec in per_fold]

    avg_importances = np.mean(filled, axis=0)

    if len(avg_importances) != len(feature_names):
        model._logger.warning(
            f"Shape mismatch: {len(avg_importances)} importances vs {len(feature_names)} names."
        )
        feature_names = [f"Feature_{i}" for i in range(len(avg_importances))]

    return pd.Series(avg_importances, index=feature_names)


def plot_feature_importance_heatmap(
    model, top_n: int = 30, skip_top: int = 0, title: Optional[str] = None
) -> Tuple[go.Figure, pd.DataFrame]:
    """
    Generate heatmap of the top feature importances with a transparent background.

    Args:
        model: Passed to ``get_calibrated_feature_importances``; its
            ``model_name`` is used in the default title.
        top_n: Number of features to show.
        skip_top: Number of highest-ranked features to skip first.
        title: Optional title; defaults to ``"Top {top_n} Features - {model_name}"``.

    Returns:
        ``(fig, df_plot)`` where ``df_plot`` has the columns ``Feature`` and
        ``Importance`` (importance divided by the largest selected value).
    """
    importances = get_calibrated_feature_importances(model)

    ranked = importances.sort_values(ascending=False)
    top_importances = ranked.iloc[skip_top : skip_top + top_n]

    # Scale to the largest selected value; skip the division when that
    # value is not positive.
    max_val = top_importances.max()
    if max_val > 0:
        norm_importances = top_importances / max_val
    else:
        norm_importances = top_importances

    df_plot = pd.DataFrame(
        {"Feature": top_importances.index, "Importance": norm_importances.values}
    )

    # A single-row image: one cell per feature.
    fig = px.imshow(
        [df_plot["Importance"].values],
        labels=dict(x="Model Features", y="", color="Relative Importance"),
        x=df_plot["Feature"],
        color_continuous_scale="Reds",
        text_auto=".2f",
        aspect="auto",
    )

    display_title = title or f"Top {top_n} Features - {model.model_name}"
    if skip_top > 0:
        display_title += f" (Skipping Top {skip_top})"

    fig.update_layout(
        title=dict(text=display_title, font=dict(size=18)),
        height=600,
        xaxis_tickangle=-45,
        yaxis=dict(showticklabels=False),
        template="plotly_white",
        margin=dict(t=60, b=120),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )

    return fig, df_plot


def plotly_to_base64(fig: go.Figure) -> str:
    """
    Converts a Plotly figure to a base64 encoded PNG string.

    Args:
        fig: The figure to render.

    Returns:
        The PNG bytes from ``fig.to_image`` encoded as a base64 string.
    """
    img_bytes = fig.to_image(format="png")
    return base64.b64encode(img_bytes).decode("utf-8")