# PayShield-ML / src/explainability.py
# Author: Sibi Krishnamoorthy
# Branch: prod
# Commit: 8a08300
"""
SHAP Explainability Engine.
Implements regulatory-compliant explainability using SHAP (SHapley Additive exPlanations).
Provides both local (per-transaction) and global (model-wide) explanations.
Based on research notebook SHAP implementation:
- TreeExplainer for XGBoost models
- Waterfall plots for local explanations
- Summary plots for global feature importance
"""
import base64
import io
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import joblib
import matplotlib

matplotlib.use("Agg")  # Non-interactive backend for server environments

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.pipeline import Pipeline
class FraudExplainer:
    """
    SHAP-based explainability engine for fraud detection model.

    Provides transparent, auditable explanations for fraud predictions:

    - **Local Explanations**: Why a specific transaction was flagged (waterfall)
    - **Global Explanations**: Overall feature importance (summary plot)

    Example:
        >>> explainer = FraudExplainer("models/fraud_model.pkl")
        >>>
        >>> # Explain a single transaction
        >>> transaction = pd.DataFrame([{...}])
        >>> waterfall_b64 = explainer.generate_waterfall(transaction)
        >>>
        >>> # Global feature importance
        >>> summary_b64 = explainer.generate_summary(X_test_sample)
    """

    def __init__(self, pipeline_path: str):
        """
        Initialize SHAP explainer with trained pipeline.

        Args:
            pipeline_path: Path to saved pipeline (.pkl file)

        Raises:
            FileNotFoundError: If pipeline file doesn't exist
            ValueError: If pipeline structure is invalid
        """
        path = Path(pipeline_path)
        if not path.exists():
            raise FileNotFoundError(f"Pipeline not found: {path}")

        # Load trained pipeline.
        # NOTE(review): joblib.load unpickles arbitrary objects -- only load
        # model files from trusted sources.
        self.pipeline: Pipeline = joblib.load(path)

        # Validate expected pipeline structure before extracting components.
        if "model" not in self.pipeline.named_steps:
            raise ValueError("Pipeline must contain 'model' step")
        if "preprocessor" not in self.pipeline.named_steps:
            raise ValueError("Pipeline must contain 'preprocessor' step")

        self.model = self.pipeline.named_steps["model"]
        self.preprocessor = self.pipeline.named_steps["preprocessor"]

        # Initialize SHAP TreeExplainer.
        # TreeExplainer is optimized for tree-based models (XGBoost, RandomForest).
        self.explainer = shap.TreeExplainer(self.model)

        # Feature names after transformation, used to label SHAP outputs/plots.
        self.feature_names = self._get_feature_names()

    def _get_feature_names(self) -> list:
        """
        Extract feature names from preprocessor.

        Returns:
            List of feature names after ColumnTransformer
        """
        try:
            # Try sklearn 1.0+ method
            return list(self.preprocessor.get_feature_names_out())
        except AttributeError:
            # Fallback: Manually construct from transformer configuration.
            # This matches our pipeline structure:
            #   cat:      ['job', 'category']
            #   num:      ['amt_log', 'age', 'distance_km', 'trans_count_24h', ...]
            #   binary:   ['gender']
            #   cyclical: ['hour_sin', 'hour_cos', 'day_sin', 'day_cos']
            # NOTE(review): this hard-coded order must stay in sync with the
            # training pipeline's ColumnTransformer -- verify on pipeline changes.
            categorical = ["job", "category"]
            numerical = [
                "amt_log",
                "age",
                "distance_km",
                "trans_count_24h",
                "amt_to_avg_ratio_24h",
                "amt_relative_to_all_time",
            ]
            binary = ["gender"]
            cyclical = ["hour_sin", "hour_cos", "day_sin", "day_cos"]
            return categorical + numerical + binary + cyclical

    def _transform_data(self, X: pd.DataFrame) -> np.ndarray:
        """
        Transform raw transaction data through pipeline preprocessor.

        This is the crucial step mentioned in the notebook to resolve
        "You have categorical data..." errors.

        Args:
            X: Raw transaction DataFrame

        Returns:
            Transformed numerical array ready for SHAP
        """
        # Apply feature extraction first (if a 'features' step exists),
        # mirroring the order the pipeline itself would run.
        if "features" in self.pipeline.named_steps:
            X = self.pipeline.named_steps["features"].transform(X)
        # Apply preprocessing (WOE, scaling, passthrough)
        return self.preprocessor.transform(X)

    @staticmethod
    def _positive_class_values(shap_values) -> np.ndarray:
        """
        Normalize TreeExplainer output to a single 2-D array.

        For XGBoost binary models ``shap_values`` is already a 2-D array;
        for some model types (e.g. sklearn RandomForest) TreeExplainer
        returns a per-class *list* -- take the positive (fraud) class.

        Args:
            shap_values: Raw output of ``TreeExplainer.shap_values``

        Returns:
            2-D array of SHAP values (rows x features)
        """
        if isinstance(shap_values, list):
            return np.asarray(shap_values[-1])
        return np.asarray(shap_values)

    def _scalar_base_value(self) -> float:
        """
        Return the explainer's expected value as a plain float.

        ``expected_value`` is a scalar for XGBoost binary models but an
        array/list (one entry per class) for some other tree models; in
        that case use the positive-class entry.

        Returns:
            Base (expected) model output as float
        """
        base_value = self.explainer.expected_value
        if isinstance(base_value, (list, np.ndarray)):
            base_value = np.ravel(base_value)[-1]
        return float(base_value)

    def calculate_shap_values(
        self, X: pd.DataFrame, transformed: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Calculate SHAP values for input data.

        Args:
            X: Transaction data (raw or transformed)
            transformed: If True, X is already transformed. If False, transform it.

        Returns:
            Tuple of (shap_values, transformed_X)
        """
        X_transformed = X if transformed else self._transform_data(X)
        # Raw TreeExplainer output is returned unmodified for callers that
        # expect the library's native structure.
        shap_values = self.explainer.shap_values(X_transformed)
        return shap_values, X_transformed

    def generate_waterfall(
        self, transaction: pd.DataFrame, return_base64: bool = True, max_display: int = 10
    ) -> Union[str, matplotlib.figure.Figure]:
        """
        Generate SHAP waterfall plot for a single transaction.

        Shows how each feature contributed to pushing the prediction
        from the base value (average) to the final prediction.

        Args:
            transaction: Single transaction DataFrame (1 row)
            return_base64: If True, return base64 PNG. If False, return Figure.
            max_display: Maximum features to display

        Returns:
            Base64-encoded PNG string or matplotlib Figure

        Raises:
            ValueError: If `transaction` does not contain exactly one row

        Example:
            >>> waterfall_img = explainer.generate_waterfall(transaction_df)
            >>> # Save to file
            >>> with open('waterfall.png', 'wb') as f:
            ...     f.write(base64.b64decode(waterfall_img))
        """
        if len(transaction) != 1:
            raise ValueError(f"Expected 1 transaction, got {len(transaction)}")

        # Transform and wrap with feature names so the plot is labeled.
        X_transformed = self._transform_data(transaction)
        X_df = pd.DataFrame(X_transformed, columns=self.feature_names)

        # Generate SHAP Explanation object (new-style API used by shap.plots).
        explanation = self.explainer(X_df)

        # Draw the waterfall on a fresh figure; show=False keeps it off-screen.
        fig = plt.figure(figsize=(10, 6))
        shap.plots.waterfall(explanation[0], max_display=max_display, show=False)
        plt.tight_layout()

        if return_base64:
            # _plot_to_base64 also closes the figure.
            return self._plot_to_base64(fig)
        return fig

    def generate_summary(
        self, X_sample: pd.DataFrame, return_base64: bool = True, max_display: int = 20
    ) -> Union[str, matplotlib.figure.Figure]:
        """
        Generate SHAP summary plot for global feature importance.

        Shows which features are most important across all predictions.
        Each dot represents a transaction, color indicates feature value.

        Args:
            X_sample: Sample of transactions (typically 100-1000 rows)
            return_base64: If True, return base64 PNG. If False, return Figure.
            max_display: Maximum features to display

        Returns:
            Base64-encoded PNG string or matplotlib Figure

        Example:
            >>> # Analyze 500 test transactions
            >>> summary_img = explainer.generate_summary(X_test[:500])
        """
        # Transform data and compute SHAP values over the whole sample.
        X_transformed = self._transform_data(X_sample)
        shap_values = self.explainer.shap_values(X_transformed)

        # Draw the beeswarm summary plot; show=False keeps it off-screen.
        fig = plt.figure(figsize=(10, 8))
        shap.summary_plot(
            shap_values,
            X_transformed,
            feature_names=self.feature_names,
            max_display=max_display,
            show=False,
        )
        plt.tight_layout()

        if return_base64:
            # _plot_to_base64 also closes the figure.
            return self._plot_to_base64(fig)
        return fig

    def explain_prediction(
        self, transaction: pd.DataFrame, threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Get comprehensive explanation for a single prediction.

        Args:
            transaction: Single transaction DataFrame
            threshold: Decision threshold

        Returns:
            Dictionary with:
            - prediction: fraud probability
            - decision: "BLOCK" or "APPROVE"
            - threshold: the threshold used
            - shap_values: feature contributions
            - top_features: top 5 features sorted by impact
            - base_value: model's base prediction (average)

        Example:
            >>> explanation = explainer.explain_prediction(transaction_df, threshold=0.895)
            >>> print(explanation['decision'])  # "BLOCK"
            >>> print(explanation['top_features'])
            [{'feature': 'amt_log', 'impact': 0.32}, ...]
        """
        # Probability of the positive (fraud) class from the full pipeline.
        y_prob = self.pipeline.predict_proba(transaction)[0, 1]

        # Transform for SHAP and normalize output to a single 2-D array
        # (handles the per-class list form some tree models produce).
        X_transformed = self._transform_data(transaction)
        row_values = self._positive_class_values(
            self.explainer.shap_values(X_transformed)
        )[0]

        # Base (expected) value, coerced to a scalar regardless of model type.
        base_value = self._scalar_base_value()

        # Rank features by absolute contribution to the prediction.
        feature_impacts = [
            {"feature": feat, "impact": float(shap_val), "abs_impact": abs(float(shap_val))}
            for feat, shap_val in zip(self.feature_names, row_values)
        ]
        feature_impacts.sort(key=lambda x: x["abs_impact"], reverse=True)

        return {
            "prediction": float(y_prob),
            "decision": "BLOCK" if y_prob >= threshold else "APPROVE",
            "threshold": threshold,
            "shap_values": {
                feat: float(val) for feat, val in zip(self.feature_names, row_values)
            },
            "top_features": feature_impacts[:5],
            "base_value": base_value,
        }

    def _plot_to_base64(self, fig: matplotlib.figure.Figure) -> str:
        """
        Convert matplotlib figure to base64-encoded PNG.

        Closes the figure after rendering to avoid leaking figure handles
        in long-running server processes.

        Args:
            fig: Matplotlib figure

        Returns:
            Base64-encoded PNG string
        """
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=100)
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode("utf-8")
        plt.close(fig)
        return img_base64
# Explicit public API of this module.
__all__ = ["FraudExplainer"]