"""
SHAP Explainability Engine.
Implements regulatory-compliant explainability using SHAP (SHapley Additive exPlanations).
Provides both local (per-transaction) and global (model-wide) explanations.
Based on research notebook SHAP implementation:
- TreeExplainer for XGBoost models
- Waterfall plots for local explanations
- Summary plots for global feature importance
"""
import base64
import io
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union

import joblib
import matplotlib

matplotlib.use("Agg")  # Non-interactive backend for server environments

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.pipeline import Pipeline
class FraudExplainer:
    """
    SHAP-based explainability engine for fraud detection model.

    Provides transparent, auditable explanations for fraud predictions:

    - **Local Explanations**: Why a specific transaction was flagged (waterfall)
    - **Global Explanations**: Overall feature importance (summary plot)

    Example:
        >>> explainer = FraudExplainer("models/fraud_model.pkl")
        >>>
        >>> # Explain a single transaction
        >>> transaction = pd.DataFrame([{...}])
        >>> waterfall_b64 = explainer.generate_waterfall(transaction)
        >>>
        >>> # Global feature importance
        >>> summary_b64 = explainer.generate_summary(X_test_sample)
    """

    def __init__(self, pipeline_path: str):
        """
        Initialize SHAP explainer with trained pipeline.

        Args:
            pipeline_path: Path to saved pipeline (.pkl file)

        Raises:
            FileNotFoundError: If pipeline file doesn't exist
            ValueError: If pipeline structure is invalid
        """
        pipeline_path = Path(pipeline_path)
        if not pipeline_path.exists():
            raise FileNotFoundError(f"Pipeline not found: {pipeline_path}")

        # Load trained pipeline
        self.pipeline: Pipeline = joblib.load(pipeline_path)

        # Validate the expected pipeline structure before extracting components
        if "model" not in self.pipeline.named_steps:
            raise ValueError("Pipeline must contain 'model' step")
        if "preprocessor" not in self.pipeline.named_steps:
            raise ValueError("Pipeline must contain 'preprocessor' step")

        self.model = self.pipeline.named_steps["model"]
        self.preprocessor = self.pipeline.named_steps["preprocessor"]

        # TreeExplainer is optimized for tree-based models (XGBoost, RandomForest)
        self.explainer = shap.TreeExplainer(self.model)

        # Feature names in post-transformation order, used to label plots
        self.feature_names = self._get_feature_names()

    def _get_feature_names(self) -> list:
        """
        Extract feature names from preprocessor.

        Returns:
            List of feature names after ColumnTransformer
        """
        try:
            # sklearn 1.0+ exposes transformed names directly
            return list(self.preprocessor.get_feature_names_out())
        except AttributeError:
            # Fallback: manually construct from the known transformer layout:
            #   cat:      ['job', 'category']
            #   num:      ['amt_log', 'age', 'distance_km', 'trans_count_24h', ...]
            #   binary:   ['gender']
            #   cyclical: ['hour_sin', 'hour_cos', 'day_sin', 'day_cos']
            # NOTE(review): this hard-coded order must match the actual
            # ColumnTransformer output order — verify against the training code.
            categorical = ["job", "category"]
            numerical = [
                "amt_log",
                "age",
                "distance_km",
                "trans_count_24h",
                "amt_to_avg_ratio_24h",
                "amt_relative_to_all_time",
            ]
            binary = ["gender"]
            cyclical = ["hour_sin", "hour_cos", "day_sin", "day_cos"]
            return categorical + numerical + binary + cyclical

    @staticmethod
    def _positive_class_values(shap_values: Union[list, np.ndarray]) -> np.ndarray:
        """
        Normalize raw explainer output to the positive-class SHAP array.

        Some tree backends (e.g. sklearn ensembles) return one array per
        class; XGBoost binary models return a single array. The last entry
        corresponds to the positive (fraud) class.

        Args:
            shap_values: Output of ``TreeExplainer.shap_values``

        Returns:
            2-D array of SHAP values for the positive class
        """
        if isinstance(shap_values, list):
            return shap_values[-1]
        return shap_values

    def _transform_data(self, X: pd.DataFrame) -> np.ndarray:
        """
        Transform raw transaction data through pipeline preprocessor.

        This is the crucial step mentioned in the notebook to resolve
        "You have categorical data..." errors.

        Args:
            X: Raw transaction DataFrame

        Returns:
            Transformed numerical array ready for SHAP
        """
        # Apply feature extraction (only if a 'features' step exists)
        if "features" in self.pipeline.named_steps:
            X = self.pipeline.named_steps["features"].transform(X)

        # Apply preprocessing (WOE, scaling, passthrough)
        return self.preprocessor.transform(X)

    def calculate_shap_values(
        self, X: pd.DataFrame, transformed: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Calculate SHAP values for input data.

        Args:
            X: Transaction data (raw or transformed)
            transformed: If True, X is already transformed. If False, transform it.

        Returns:
            Tuple of (shap_values, transformed_X)
        """
        X_transformed = X if transformed else self._transform_data(X)
        shap_values = self.explainer.shap_values(X_transformed)
        return shap_values, X_transformed

    def generate_waterfall(
        self, transaction: pd.DataFrame, return_base64: bool = True, max_display: int = 10
    ) -> Union[str, matplotlib.figure.Figure]:
        """
        Generate SHAP waterfall plot for a single transaction.

        Shows how each feature contributed to pushing the prediction
        from the base value (average) to the final prediction.

        Args:
            transaction: Single transaction DataFrame (1 row)
            return_base64: If True, return base64 PNG. If False, return Figure.
            max_display: Maximum features to display

        Returns:
            Base64-encoded PNG string or matplotlib Figure

        Raises:
            ValueError: If more than one transaction row is supplied

        Example:
            >>> waterfall_img = explainer.generate_waterfall(transaction_df)
            >>> # Save to file
            >>> with open('waterfall.png', 'wb') as f:
            ...     f.write(base64.b64decode(waterfall_img))
        """
        if len(transaction) != 1:
            raise ValueError(f"Expected 1 transaction, got {len(transaction)}")

        # Transform raw features, then wrap in a DataFrame so plots are labelled
        X_transformed = self._transform_data(transaction)
        X_df = pd.DataFrame(X_transformed, columns=self.feature_names)

        # Generate SHAP Explanation object (callable API)
        explanation = self.explainer(X_df)

        fig = plt.figure(figsize=(10, 6))
        try:
            shap.plots.waterfall(explanation[0], max_display=max_display, show=False)
            plt.tight_layout()
        except Exception:
            # Don't leak the figure if plotting fails
            plt.close(fig)
            raise

        if return_base64:
            return self._plot_to_base64(fig)
        return fig

    def generate_summary(
        self, X_sample: pd.DataFrame, return_base64: bool = True, max_display: int = 20
    ) -> Union[str, matplotlib.figure.Figure]:
        """
        Generate SHAP summary plot for global feature importance.

        Shows which features are most important across all predictions.
        Each dot represents a transaction, color indicates feature value.

        Args:
            X_sample: Sample of transactions (typically 100-1000 rows)
            return_base64: If True, return base64 PNG. If False, return Figure.
            max_display: Maximum features to display

        Returns:
            Base64-encoded PNG string or matplotlib Figure

        Example:
            >>> # Analyze 500 test transactions
            >>> summary_img = explainer.generate_summary(X_test[:500])
        """
        X_transformed = self._transform_data(X_sample)
        shap_values = self.explainer.shap_values(X_transformed)

        fig = plt.figure(figsize=(10, 8))
        try:
            shap.summary_plot(
                shap_values,
                X_transformed,
                feature_names=self.feature_names,
                max_display=max_display,
                show=False,
            )
            plt.tight_layout()
        except Exception:
            # Don't leak the figure if plotting fails
            plt.close(fig)
            raise

        if return_base64:
            return self._plot_to_base64(fig)
        return fig

    def explain_prediction(
        self, transaction: pd.DataFrame, threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Get comprehensive explanation for a single prediction.

        Args:
            transaction: Single transaction DataFrame (1 row)
            threshold: Decision threshold

        Returns:
            Dictionary with:
                - prediction: fraud probability
                - decision: "BLOCK" or "APPROVE"
                - threshold: the threshold used
                - shap_values: feature contributions
                - top_features: top 5 features sorted by impact
                - base_value: model's base prediction (average)

        Raises:
            ValueError: If more than one transaction row is supplied

        Example:
            >>> explanation = explainer.explain_prediction(transaction_df, threshold=0.895)
            >>> print(explanation['decision'])  # "BLOCK"
            >>> print(explanation['top_features'])
            [{'feature': 'amt_log', 'impact': 0.32}, ...]
        """
        if len(transaction) != 1:
            raise ValueError(f"Expected 1 transaction, got {len(transaction)}")

        # Probability of the positive (fraud) class
        y_prob = float(self.pipeline.predict_proba(transaction)[0, 1])

        # Transform for SHAP and normalize to the positive-class array
        X_transformed = self._transform_data(transaction)
        raw_shap = self.explainer.shap_values(X_transformed)
        shap_row = self._positive_class_values(raw_shap)[0]

        # expected_value may be a scalar, or a list/array with one entry per
        # class depending on the model backend; take the positive class.
        base_value = self.explainer.expected_value
        if isinstance(base_value, (list, np.ndarray)):
            base_value = np.ravel(base_value)[-1]

        # Rank features by absolute contribution magnitude
        feature_impacts = [
            {"feature": feat, "impact": float(val), "abs_impact": abs(float(val))}
            for feat, val in zip(self.feature_names, shap_row)
        ]
        feature_impacts.sort(key=lambda x: x["abs_impact"], reverse=True)

        return {
            "prediction": y_prob,
            "decision": "BLOCK" if y_prob >= threshold else "APPROVE",
            "threshold": threshold,
            "shap_values": {
                feat: float(val) for feat, val in zip(self.feature_names, shap_row)
            },
            "top_features": feature_impacts[:5],
            "base_value": float(base_value),
        }

    def _plot_to_base64(self, fig: matplotlib.figure.Figure) -> str:
        """
        Convert matplotlib figure to base64-encoded PNG.

        The figure is always closed, even if rendering fails, to avoid
        accumulating open figures in a long-running server process.

        Args:
            fig: Matplotlib figure

        Returns:
            Base64-encoded PNG string
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=100)
        finally:
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode("utf-8")
# Explicit public API for `from <module> import *`.
__all__ = ["FraudExplainer"]