| import base64 |
| import io |
| import shap |
| import pandas as pd |
| import numpy as np |
| from typing import Optional, Tuple |
| import plotly.express as px |
| import plotly.graph_objects as go |
| import plotly.io as pio |
|
|
def _to_shap_vector(shap_values) -> np.ndarray:
    """Reduce SHAP output for one explained row to an array of per-feature values.

    Handling, in order:
      * A list whose items are themselves array-like is treated as per-class
        output: item 1 is used when there are several classes, otherwise item 0.
        A flat list of numbers is treated as the per-feature vector itself.
      * Objects exposing ``.values`` (e.g. pandas objects) are unwrapped.
      * Shape ``(1, n_features, n_classes)`` is reduced along the class axis with
        the same item-1 / item-0 rule as the list form.
      * Shape ``(1, n_features)`` is reduced to its single row.
    Any other shape is passed through unchanged.
    """
    vals = shap_values
    # Only a list of arrays is per-class output; a list of floats is already the vector.
    if isinstance(vals, list) and vals and np.ndim(vals[0]) > 0:
        vals = vals[1] if len(vals) > 1 else vals[0]
    vals = np.asarray(getattr(vals, "values", vals))
    if vals.ndim == 3 and vals.shape[0] == 1:
        class_idx = 1 if vals.shape[2] > 1 else 0
        vals = vals[0, :, class_idx]
    elif vals.ndim == 2 and vals.shape[0] == 1:
        vals = vals[0]
    return vals


def generate_shap_summary_plot_base64(shap_values, X_proc, feature_names=None, target_class=None) -> str:
    """
    Generates a SHAP summary plot using Plotly (Strip Plot) and returns it as a base64 string.

    One point is drawn per feature at its SHAP value, coloured by sign. Rows are
    sorted by ascending absolute SHAP value before plotting.

    Args:
        shap_values: SHAP values for a single explained row; see ``_to_shap_vector``
            for the accepted layouts.
        X_proc: Processed input. Only used to derive feature names when
            ``feature_names`` is None (its ``columns``, else its column count).
        feature_names: Optional explicit feature names, in SHAP-value order.
        target_class: Optional label appended to the title as "(Predicted: ...)".
            Every value except None (or one that stringifies to "") is shown,
            including falsy labels such as 0.

    Returns:
        Base64-encoded PNG (rendered with kaleido at scale 2), or "" if any step
        fails; the error is printed rather than raised.
    """
    try:
        if feature_names is None:
            if hasattr(X_proc, "columns"):
                feature_names = list(X_proc.columns)
            else:
                n_columns = X_proc.shape[1] if hasattr(X_proc, "shape") else len(X_proc[0])
                feature_names = [f"Feature {i}" for i in range(n_columns)]

        vals = _to_shap_vector(shap_values)

        df_plot = pd.DataFrame({
            "Feature": feature_names,
            "SHAP": vals
        })

        # Strictly positive values are "risk"; zero and negative values are "protective".
        df_plot["Type"] = np.where(df_plot["SHAP"] > 0, "Risk (Positive)", "Protective (Negative)")

        df_plot["AbsSHAP"] = df_plot["SHAP"].abs()
        df_plot = df_plot.sort_values("AbsSHAP", ascending=True)

        # Compare against None explicitly so class labels such as 0 are not dropped.
        class_label = "" if target_class is None else str(target_class)
        title = "SHAP Impact Analysis"
        if class_label:
            title += f" (Predicted: {class_label})"

        fig = px.strip(
            df_plot,
            x='SHAP',
            y='Feature',
            color='Type',
            stripmode='overlay',
            color_discrete_map={
                "Risk (Positive)": "#ef4444",
                "Protective (Negative)": "#10b981"
            },
            title=title
        )

        fig.update_layout(
            xaxis=dict(
                title="SHAP Value (Impact on Model Probability)",
                showgrid=True,
                gridcolor='WhiteSmoke',
                zerolinecolor='Gainsboro'
            ),
            yaxis=dict(
                title="Feature",
                showgrid=True,
                gridcolor='WhiteSmoke',
                zerolinecolor='Gainsboro'
            ),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            height=max(500, len(feature_names) * 40),
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            )
        )

        fig.update_traces(jitter=1, marker=dict(size=12, opacity=0.9, line=dict(width=1, color='DarkSlateGrey')))

        img_bytes = fig.to_image(format="png", engine="kaleido", scale=2)
        return base64.b64encode(img_bytes).decode("utf-8")

    except Exception as e:
        # Best-effort rendering: report and return an empty string instead of raising.
        print(f"Error generating Plotly SHAP image: {e}")
        return ""
|
|
|
|
def get_calibrated_feature_importances(model) -> pd.Series:
    """
    Safely extract and aggregate feature importances from a calibrated production model.

    Feature names come from ``model.preprocessor.get_feature_names_out()`` or,
    failing that, ``model.preprocessor.pipeline.get_feature_names_out()``. If
    neither is reachable, generic ``Feature_{i}`` names sized by
    ``model.calibrator.n_features_in_`` are used.

    For each entry of ``model.calibrator.calibrated_classifiers_``, the wrapped
    estimator (``estimator``, or the legacy ``base_estimator``) contributes:
      * its ``feature_importances_`` when present,
      * otherwise the mean absolute ``coef_`` across coefficient rows,
      * otherwise a vector of zeros.
    The per-estimator vectors are averaged element-wise. If the result's length
    does not match the names, generic ``Feature_{i}`` names are used instead
    and a warning is logged.

    Args:
        model: Object exposing ``preprocessor``, ``calibrator`` and ``_logger``
            (presumably a fitted ``CalibratedClassifierCV`` as ``calibrator`` —
            TODO confirm against callers).

    Returns:
        A ``pd.Series`` of averaged importances indexed by feature name.

    Raises:
        ValueError: If the calibrator has no ``calibrated_classifiers_`` or that
            collection is empty.
    """
    try:
        if hasattr(model.preprocessor, "get_feature_names_out"):
            feature_names = model.preprocessor.get_feature_names_out()
        else:
            feature_names = model.preprocessor.pipeline.get_feature_names_out()
    except AttributeError:
        model._logger.warning("Could not extract feature names. Using generic names.")
        feature_names = [f"Feature_{i}" for i in range(model.calibrator.n_features_in_)]

    calibrator = model.calibrator
    calibrated_classifiers = getattr(calibrator, "calibrated_classifiers_", None)
    if calibrated_classifiers is None:
        raise ValueError(
            "Calibrator is missing 'calibrated_classifiers_'. Is it fitted?"
        )
    if len(calibrated_classifiers) == 0:
        # np.mean over an empty list would yield a scalar NaN and fail later on len().
        raise ValueError(
            "Calibrator has no calibrated classifiers to read importances from."
        )

    importances_list = []
    for calibrated_clf in calibrated_classifiers:
        # Only fall back to the legacy attribute name when `estimator` is absent.
        base_model = getattr(calibrated_clf, "estimator", None)
        if base_model is None:
            base_model = getattr(calibrated_clf, "base_estimator", None)

        if hasattr(base_model, "feature_importances_"):
            importances_list.append(np.asarray(base_model.feature_importances_, dtype=float))
        elif hasattr(base_model, "coef_"):
            # atleast_2d keeps a 1-D coef_ from collapsing to a scalar under mean(axis=0).
            importances_list.append(np.abs(np.atleast_2d(base_model.coef_)).mean(axis=0))
        else:
            importances_list.append(np.zeros(len(feature_names)))

    avg_importances = np.mean(importances_list, axis=0)

    if len(avg_importances) != len(feature_names):
        model._logger.warning(
            f"Shape mismatch: {len(avg_importances)} importances vs {len(feature_names)} names."
        )
        feature_names = [f"Feature_{i}" for i in range(len(avg_importances))]

    return pd.Series(avg_importances, index=feature_names)
|
|
|
|
def plot_feature_importance_heatmap(
    model, top_n: int = 30, skip_top: int = 0, title: Optional[str] = None
) -> Tuple[go.Figure, pd.DataFrame]:
    """
    Generate heatmap of the top feature importances with a transparent background.

    Importances from ``get_calibrated_feature_importances(model)`` are sorted in
    descending order. The features ranked from ``skip_top`` up to, but excluding,
    ``skip_top + top_n`` are drawn as a single-row heatmap. Values are divided by
    the largest shown value, so the strongest shown feature reads 1.0. They are
    left unscaled when that maximum is not positive.

    Args:
        model: Object accepted by ``get_calibrated_feature_importances``. Its
            ``model_name`` attribute is used in the default title.
        top_n: Maximum number of features to show (must be >= 1).
        skip_top: Number of highest-ranked features to skip first (must be >= 0).
        title: Optional title. When empty or None, a default
            "Top N Features - <model_name>" title is built from the number of
            features actually shown.

    Returns:
        ``(fig, df_plot)``, where ``df_plot`` has a ``Feature`` column and a
        normalized ``Importance`` column, in plotted order.

    Raises:
        ValueError: If ``top_n`` < 1, ``skip_top`` < 0, or ``skip_top`` skips
            every available feature.
    """
    if top_n < 1:
        raise ValueError(f"top_n must be at least 1, got {top_n}.")
    if skip_top < 0:
        raise ValueError(f"skip_top must be non-negative, got {skip_top}.")

    importances = get_calibrated_feature_importances(model)

    top_importances = importances.sort_values(ascending=False).iloc[
        skip_top : skip_top + top_n
    ]
    if top_importances.empty:
        # Fail clearly here instead of handing an empty matrix to plotly.
        raise ValueError(
            f"skip_top={skip_top} leaves no features to plot "
            f"(only {len(importances)} available)."
        )

    max_val = top_importances.max()
    norm_importances = top_importances / max_val if max_val > 0 else top_importances

    df_plot = pd.DataFrame(
        {"Feature": top_importances.index, "Importance": norm_importances.values}
    )

    # A single-row matrix yields a one-strip heatmap with one cell per feature.
    fig = px.imshow(
        [df_plot["Importance"].values],
        labels=dict(x="Model Features", y="", color="Relative Importance"),
        x=df_plot["Feature"],
        color_continuous_scale="Reds",
        text_auto=".2f",
        aspect="auto",
    )

    # Report the number of features actually shown; fewer than top_n may exist.
    display_title = title or f"Top {len(df_plot)} Features - {model.model_name}"
    if skip_top > 0:
        display_title += f" (Skipping Top {skip_top})"

    fig.update_layout(
        title=dict(text=display_title, font=dict(size=18)),
        height=600,
        xaxis_tickangle=-45,
        yaxis=dict(showticklabels=False),
        template="plotly_white",
        margin=dict(t=60, b=120),
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )

    return fig, df_plot
|
|
|
|
def plotly_to_base64(fig: go.Figure) -> str:
    """
    Render a Plotly figure as PNG and return the image as base64 text.

    Calls ``fig.to_image(format="png")`` with no explicit engine or scale, then
    base64-encodes the bytes and decodes them as UTF-8.
    """
    return base64.b64encode(fig.to_image(format="png")).decode("utf-8")
|
|