# boq-api/utils/plot_utils.py
# Repository web-viewer residue kept for provenance:
#   gabcares - "Upload 80 files" - commit 72fdabd (verified) - 7.01 kB
import base64
import io
import shap
import pandas as pd
import numpy as np
from typing import Optional, Tuple
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
def _normalize_shap_matrix(shap_values):
    """Coerce supported SHAP inputs into a 2-D float array of shape (n_samples, n_features)."""
    vals = shap_values
    # A list whose entries are themselves arrays is treated as a per-class list
    # (second entry preferred, as before). A flat list of scalars is already
    # the SHAP vector and must not be indexed into.
    if isinstance(vals, list) and vals and np.ndim(vals[0]) > 0:
        vals = vals[1] if len(vals) > 1 else vals[0]

    matrix = np.asarray(vals, dtype=float)
    if matrix.ndim == 1:
        matrix = matrix.reshape(1, -1)  # single sample -> one row
    if matrix.ndim != 2 or matrix.size == 0:
        raise ValueError(
            f"Expected SHAP values shaped (n_features,) or (n_samples, n_features); got shape {matrix.shape}"
        )
    return matrix


def generate_shap_summary_plot_base64(shap_values, X_proc, feature_names=None, target_class=None) -> str:
    """
    Generates a SHAP summary plot using Plotly (Strip Plot) and returns it as a base64 string.

    Args:
        shap_values: SHAP values for one sample (``(n_features,)`` or
            ``(1, n_features)``) or for several samples
            (``(n_samples, n_features)``). A list of per-class arrays is also
            accepted; its second entry is used when present, else the first.
        X_proc: Only consulted when ``feature_names`` is None: its ``columns``
            are used if present, otherwise generic ``"Feature i"`` names are
            sized from ``X_proc.shape[1]`` or ``len(X_proc[0])``.
        feature_names: Optional names, one per SHAP column.
        target_class: Optional predicted class shown in the title. Any value
            other than ``None`` or an empty string is shown (including ``0``).

    Returns:
        The PNG image encoded as a base64 string, or ``""`` if anything fails
        (the error is printed; this function never raises).
    """
    try:
        # 1. Prepare Feature Names
        if feature_names is None:
            if hasattr(X_proc, "columns"):
                feature_names = list(X_proc.columns)
            else:
                n_cols = X_proc.shape[1] if hasattr(X_proc, "shape") else len(X_proc[0])
                feature_names = [f"Feature {i}" for i in range(n_cols)]
        feature_names = list(feature_names)

        # 2. Normalize SHAP values to (n_samples, n_features)
        matrix = _normalize_shap_matrix(shap_values)
        n_samples, n_features = matrix.shape
        if n_features != len(feature_names):
            raise ValueError(
                f"Got {n_features} SHAP values per sample but {len(feature_names)} feature names"
            )

        # 3. Create long-form DataFrame: one row per (sample, feature) pair.
        # ravel() is row-major, so repeating the name list once per sample lines up.
        df_plot = pd.DataFrame({
            "Feature": feature_names * n_samples,
            "SHAP": matrix.ravel(),
        })
        # Add coloring based on impact direction (Risk/Protective)
        df_plot["Type"] = np.where(df_plot["SHAP"] > 0, "Risk (Positive)", "Protective (Negative)")
        df_plot["AbsSHAP"] = df_plot["SHAP"].abs()

        # Rank features by mean |SHAP|, ascending, so the most important
        # feature ends up at the top of the y-axis.
        feature_order = (
            df_plot.groupby("Feature", sort=False)["AbsSHAP"]
            .mean()
            .sort_values(ascending=True)
            .index.tolist()
        )
        df_plot = df_plot.sort_values("AbsSHAP", ascending=True)

        # `0` is a valid class label, so test against None / empty text
        # rather than relying on truthiness.
        show_class = target_class is not None and str(target_class) != ""
        title_suffix = f" (Predicted: {target_class})" if show_class else ""

        # 4. Generate Plotly Strip Plot
        fig = px.strip(
            df_plot,
            x='SHAP',
            y='Feature',
            color='Type',
            stripmode='overlay',
            color_discrete_map={
                "Risk (Positive)": "#ef4444",
                "Protective (Negative)": "#10b981"
            },
            title=f"SHAP Impact Analysis{title_suffix}"
        )
        fig.update_layout(
            xaxis=dict(
                title="SHAP Value (Impact on Model Probability)",
                showgrid=True,
                gridcolor='WhiteSmoke',
                zerolinecolor='Gainsboro'
            ),
            yaxis=dict(
                title="Feature",
                showgrid=True,
                gridcolor='WhiteSmoke',
                zerolinecolor='Gainsboro',
                # The color split yields one trace per Type, and the default
                # category order follows trace appearance; pin it explicitly
                # so the importance ranking survives.
                categoryorder="array",
                categoryarray=feature_order
            ),
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            height=max(500, len(feature_names) * 40),
            legend=dict(
                orientation="h",
                yanchor="bottom",
                y=1.02,
                xanchor="right",
                x=1
            )
        )
        fig.update_traces(jitter=1, marker=dict(size=12, opacity=0.9, line=dict(width=1, color='DarkSlateGrey')))

        # 5. Export to Base64 Image
        # Requires 'kaleido' package installed
        img_bytes = fig.to_image(format="png", engine="kaleido", scale=2)
        return base64.b64encode(img_bytes).decode("utf-8")
    except Exception as e:
        # Deliberate best-effort: callers get an empty string instead of an exception.
        print(f"Error generating Plotly SHAP image: {e}")
        return ""
def get_calibrated_feature_importances(model) -> pd.Series:
    """
    Safely extract and aggregate feature importances from a calibrated production model.

    For every entry of ``model.calibrator.calibrated_classifiers_`` the wrapped
    estimator (``estimator``, falling back to ``base_estimator``) contributes:
    its ``feature_importances_`` if present, else the mean absolute ``coef_``
    across rows, else a zero vector. The contributions are averaged.

    Feature names come from ``model.preprocessor.get_feature_names_out()`` (or
    ``model.preprocessor.pipeline.get_feature_names_out()``). If they cannot
    be obtained, or their count does not match the importances, generic
    ``Feature_<i>`` names are used and a warning is logged on ``model._logger``.

    Args:
        model: Object exposing ``preprocessor``, ``calibrator`` and ``_logger``.

    Returns:
        A Series of averaged importances indexed by feature name.

    Raises:
        ValueError: If the calibrator has no ``calibrated_classifiers_``
            attribute, or that collection is empty.
    """
    try:
        if hasattr(model.preprocessor, "get_feature_names_out"):
            feature_names = model.preprocessor.get_feature_names_out()
        else:
            feature_names = model.preprocessor.pipeline.get_feature_names_out()
    except AttributeError:
        model._logger.warning("Could not extract feature names. Using generic names.")
        feature_names = [f"Feature_{i}" for i in range(model.calibrator.n_features_in_)]

    calibrator = model.calibrator
    if not hasattr(calibrator, "calibrated_classifiers_"):
        raise ValueError(
            "Calibrator is missing 'calibrated_classifiers_'. Is it fitted?"
        )
    if len(calibrator.calibrated_classifiers_) == 0:
        # Averaging an empty list would yield NaN and fail later with an
        # unrelated TypeError; report the real problem instead.
        raise ValueError(
            "Calibrator has an empty 'calibrated_classifiers_'. Is it fitted?"
        )

    importances_list = []
    for calibrated_clf in calibrator.calibrated_classifiers_:
        base_model = getattr(
            calibrated_clf, "estimator", getattr(calibrated_clf, "base_estimator", None)
        )
        if hasattr(base_model, "feature_importances_"):
            importances_list.append(base_model.feature_importances_)
        elif hasattr(base_model, "coef_"):
            abs_coef = np.abs(base_model.coef_)
            if np.ndim(abs_coef) < 2:
                # A 1-D coef_ is a single row of per-feature weights; without
                # this, mean(axis=0) would collapse it to one scalar.
                abs_coef = np.atleast_2d(abs_coef)
            importances_list.append(abs_coef.mean(axis=0))
        else:
            importances_list.append(np.zeros(len(feature_names)))

    avg_importances = np.mean(importances_list, axis=0)
    if len(avg_importances) != len(feature_names):
        model._logger.warning(
            f"Shape mismatch: {len(avg_importances)} importances vs {len(feature_names)} names."
        )
        feature_names = [f"Feature_{i}" for i in range(len(avg_importances))]
    return pd.Series(avg_importances, index=feature_names)
def plot_feature_importance_heatmap(
    model, top_n: int = 30, skip_top: int = 0, title: Optional[str] = None
) -> Tuple[go.Figure, pd.DataFrame]:
    """
    Build a one-row heatmap of the highest-ranked feature importances on a transparent background.

    Importances are ranked in descending order; the first ``skip_top`` entries
    are dropped and the following ``top_n`` are kept. The kept values are
    divided by their maximum when that maximum is positive, otherwise they are
    used unchanged.

    Args:
        model: Passed to ``get_calibrated_feature_importances``; its
            ``model_name`` is read only when no ``title`` is supplied.
        top_n: Number of features to display.
        skip_top: Number of leading (most important) features to leave out.
        title: Optional custom title; a default one is generated when falsy.

    Returns:
        The Plotly figure and the DataFrame (columns ``Feature`` and
        ``Importance``) it was drawn from.
    """
    ranked = get_calibrated_feature_importances(model).sort_values(ascending=False)
    window_end = skip_top + top_n
    selected = ranked.iloc[skip_top:window_end]

    # Scale relative to the strongest feature in the window (skip when not positive).
    peak = selected.max()
    if peak > 0:
        scaled = selected / peak
    else:
        scaled = selected

    df_plot = pd.DataFrame({"Feature": selected.index, "Importance": scaled.values})

    axis_labels = {"x": "Model Features", "y": "", "color": "Relative Importance"}
    single_row = [df_plot["Importance"].values]
    fig = px.imshow(
        single_row,
        labels=axis_labels,
        x=df_plot["Feature"],
        color_continuous_scale="Reds",
        text_auto=".2f",
        aspect="auto",
    )

    heading = title if title else f"Top {top_n} Features - {model.model_name}"
    suffix = f" (Skipping Top {skip_top})" if skip_top > 0 else ""
    heading = heading + suffix

    transparent = "rgba(0,0,0,0)"
    layout_options = {
        "title": {"text": heading, "font": {"size": 18}},
        "height": 600,
        "xaxis_tickangle": -45,
        "yaxis": {"showticklabels": False},
        "template": "plotly_white",
        "margin": {"t": 60, "b": 120},
        "plot_bgcolor": transparent,
        "paper_bgcolor": transparent,
    }
    fig.update_layout(**layout_options)
    return fig, df_plot
def plotly_to_base64(fig: go.Figure) -> str:
    """
    Render a Plotly figure as PNG and return the image bytes as base64 text.
    """
    png_payload = fig.to_image(format="png")
    encoded = base64.b64encode(png_payload)
    return encoded.decode("utf-8")