# boq-api / services / explainer.py
# Uploaded by gabcares: "Upload 80 files" (commit 72fdabd, verified)
# [Hugging Face file-page controls: Raw | History | Blame | Contribute | Delete]
# File size: 5.36 kB
import logging
from datetime import datetime, timezone

import numpy as np
import pandas as pd
import shap

from models.artifacts_loader import get_artifacts
from schemas.explain import ExplanationResponse, SHAPFeature, FeatureImportanceResponse
from utils.plot_utils import (
    generate_shap_summary_plot_base64,
    plot_feature_importance_heatmap,
    plotly_to_base64,
)
logger = logging.getLogger(__name__)


class ExplainerService:
    """Stateless service exposing model-explanation operations.

    Every public method is a ``@staticmethod`` that fetches the currently
    loaded artifact bundle via ``get_artifacts()`` on each call.
    """

    @staticmethod
    def explain_shap(
        item: dict, top_n: int = 10, plot: bool = False
    ) -> "ExplanationResponse":
        """Explain a single prediction with SHAP ``KernelExplainer``.

        The item is converted and normalised with the model's own helpers,
        transformed by ``model.preprocessor``, and explained against
        ``model.calibrator.predict_proba`` using ``artifacts.shap_background``
        as the background dataset. The explanation targets the class the
        model actually predicts for ``item`` (class index 0 if that cannot
        be determined).

        Args:
            item: One raw input record, passed to ``model._to_dataframe``.
            top_n: Number of features, ranked by absolute SHAP value, to
                return. Values ``<= 0`` return no features.
            plot: When true, also render a SHAP summary plot as base64.

        Returns:
            An ``ExplanationResponse`` with the top features, the optional
            plot and metadata (method name and UTC timestamp).

        Raises:
            RuntimeError: If the model is not loaded, or if any step of the
                explanation fails (including missing SHAP background data).
        """
        artifacts = get_artifacts()
        if not artifacts.model:
            raise RuntimeError("Model not loaded")

        try:
            # 1. Prepare input.
            model = artifacts.model
            df = model._to_dataframe([item])
            df_norm = model._normalize_boq_schema(df)
            # The SHAP background is expected to be in the transformed feature
            # space, so the inner calibrated classifier is explained on
            # transformed input.
            preprocessor = model.preprocessor
            calibrator = model.calibrator
            X_proc = preprocessor.transform(df_norm)

            # Predict to find which output class to explain.
            pred_label, target_class_index = ExplainerService._predict_target(
                model, calibrator, df
            )

            # 2. Compute SHAP. KernelExplainer is model-agnostic; the
            # calibrated classifier is not directly usable with TreeExplainer.
            if artifacts.shap_background is None:
                raise RuntimeError(
                    "SHAP background data is missing but required for KernelExplainer"
                )
            explainer = shap.KernelExplainer(
                calibrator.predict_proba, artifacts.shap_background
            )
            shap_values_all = explainer.shap_values(X_proc, nsamples=100)
            shap_vals_target = ExplainerService._select_class_shap(
                shap_values_all, target_class_index
            )
            # A single row was explained; rank features using that row.
            if shap_vals_target.ndim == 2:
                sv = shap_vals_target[0]
            else:
                sv = shap_vals_target

            # 3. Feature names, guaranteed to align with ``sv``.
            feature_names = ExplainerService._resolve_feature_names(
                preprocessor, len(sv)
            )

            # 4. Top features, most important first.
            top_features = [
                SHAPFeature(feature=str(feature_names[i]), value=float(sv[i]))
                for i in ExplainerService._top_indices(sv, top_n)
            ]

            # 5. Optional plot.
            plot_b64 = None
            if plot:
                plot_b64 = generate_shap_summary_plot_base64(
                    shap_vals_target,
                    X_proc,
                    feature_names=feature_names,
                    target_class=pred_label,
                )

            return ExplanationResponse(
                top_features=top_features,
                shap_plot_base64=plot_b64,
                metadata={
                    "method": "SHAP KernelExplainer",
                    "timestamp": ExplainerService._utc_timestamp(),
                },
            )
        except Exception as e:
            logger.exception("Error during SHAP explanation")
            raise RuntimeError(f"Explanation failed: {e}") from e

    @staticmethod
    def explain_feature_importance(
        top_n: int = 30, skip_top: int = 0
    ) -> "FeatureImportanceResponse":
        """Render the model's global feature-importance heatmap.

        Args:
            top_n: Forwarded to ``plot_feature_importance_heatmap``.
            skip_top: Forwarded to ``plot_feature_importance_heatmap``.

        Returns:
            A ``FeatureImportanceResponse`` holding the figure encoded by
            ``plotly_to_base64`` plus metadata echoing ``top_n``/``skip_top``.

        Raises:
            RuntimeError: If the model is not loaded or plotting fails.
        """
        artifacts = get_artifacts()
        if not artifacts.model:
            raise RuntimeError("Model not loaded")

        try:
            # The helper also returns the plotted data frame, unused here.
            fig, _ = plot_feature_importance_heatmap(
                model=artifacts.model,
                top_n=top_n,
                skip_top=skip_top,
                title="Global Feature Importance",
            )
            return FeatureImportanceResponse(
                plot_base64=plotly_to_base64(fig),
                metadata={
                    "method": "Averaged Calibrated Feature Importances",
                    "timestamp": ExplainerService._utc_timestamp(),
                    "top_n": top_n,
                    "skip_top": skip_top,
                },
            )
        except Exception as e:
            logger.exception("Error during feature importance generation")
            raise RuntimeError(f"Feature importance generation failed: {e}") from e

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _predict_target(model, calibrator, df):
        """Return ``(predicted_label, class_index)``, or ``("Unknown", 0)``.

        Best-effort: failure to predict or to locate the label in
        ``calibrator.classes_`` is logged and falls back to class index 0.
        """
        try:
            pred_label = model.predict(df)[0]
            return pred_label, list(calibrator.classes_).index(pred_label)
        except Exception:
            logger.warning(
                "Could not determine predicted class; explaining class index 0",
                exc_info=True,
            )
            return "Unknown", 0

    @staticmethod
    def _select_class_shap(shap_values_all, class_index: int) -> np.ndarray:
        """Extract the SHAP values for one output class as an ndarray.

        Handles both SHAP output layouts: a list with one array per class,
        or a 3-D array whose last axis is the class. An out-of-range class
        index falls back to class 0 in both layouts. Any other array is
        returned unchanged.
        """
        if isinstance(shap_values_all, list):
            idx = class_index if 0 <= class_index < len(shap_values_all) else 0
            return np.asarray(shap_values_all[idx])
        values = np.asarray(shap_values_all)
        if values.ndim == 3:
            idx = class_index if 0 <= class_index < values.shape[2] else 0
            return values[:, :, idx]
        return values

    @staticmethod
    def _resolve_feature_names(preprocessor, n_features: int) -> list:
        """Return ``n_features`` feature names for the transformed space.

        Names come from ``preprocessor.get_feature_names_out()`` or, if that
        attribute is absent, ``preprocessor.pipeline.get_feature_names_out()``.
        Falls back to generic ``fet_<i>`` names when lookup fails or the
        count does not match the SHAP vector, so labels never misalign.
        """
        try:
            if hasattr(preprocessor, "get_feature_names_out"):
                names = preprocessor.get_feature_names_out()
            else:
                names = preprocessor.pipeline.get_feature_names_out()
            names = [str(name) for name in names]
        except Exception:
            names = None
        if names is None or len(names) != n_features:
            logger.warning(
                "Feature names unavailable or mismatched (%s vs %d); using generic names",
                None if names is None else len(names),
                n_features,
            )
            return [f"fet_{i}" for i in range(n_features)]
        return names

    @staticmethod
    def _top_indices(values, top_n: int) -> np.ndarray:
        """Indices of the ``top_n`` largest ``|values|``, most important first.

        ``top_n <= 0`` yields an empty index array (the previous
        ``[-top_n:]`` slice returned every index for ``top_n == 0``).
        """
        if top_n <= 0:
            return np.array([], dtype=int)
        return np.argsort(np.abs(values))[::-1][:top_n]

    @staticmethod
    def _utc_timestamp() -> str:
        """Current time as a timezone-aware UTC ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()