File size: 10,858 Bytes
8a08300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
SHAP Explainability Engine.

Implements regulatory-compliant explainability using SHAP (SHapley Additive exPlanations).
Provides both local (per-transaction) and global (model-wide) explanations.

Based on research notebook SHAP implementation:
- TreeExplainer for XGBoost models
- Waterfall plots for local explanations
- Summary plots for global feature importance
"""

import base64
import io
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import joblib
import matplotlib

matplotlib.use("Agg")  # Non-interactive backend for server environments
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from sklearn.pipeline import Pipeline


class FraudExplainer:
    """
    SHAP-based explainability engine for fraud detection model.

    Provides transparent, auditable explanations for fraud predictions:
    - **Local Explanations**: Why a specific transaction was flagged (waterfall)
    - **Global Explanations**: Overall feature importance (summary plot)

    Example:
        >>> explainer = FraudExplainer("models/fraud_model.pkl")
        >>>
        >>> # Explain a single transaction
        >>> transaction = pd.DataFrame([{...}])
        >>> waterfall_b64 = explainer.generate_waterfall(transaction)
        >>>
        >>> # Global feature importance
        >>> summary_b64 = explainer.generate_summary(X_test_sample)
    """

    # Index of the fraud class: the ``predict_proba`` column that is reported
    # as the prediction, and the matching slot in per-class SHAP output.
    _POSITIVE_CLASS = 1

    # Column order used when the preprocessor cannot report its own output
    # names. This matches our pipeline structure:
    #   cat:      ['job', 'category']
    #   num:      ['amt_log', 'age', 'distance_km', 'trans_count_24h', ...]
    #   binary:   ['gender']
    #   cyclical: ['hour_sin', 'hour_cos', 'day_sin', 'day_cos']
    _FALLBACK_FEATURE_NAMES = (
        "job",
        "category",
        "amt_log",
        "age",
        "distance_km",
        "trans_count_24h",
        "amt_to_avg_ratio_24h",
        "amt_relative_to_all_time",
        "gender",
        "hour_sin",
        "hour_cos",
        "day_sin",
        "day_cos",
    )

    def __init__(self, pipeline_path: Union[str, Path]):
        """
        Initialize SHAP explainer with trained pipeline.

        Args:
            pipeline_path: Path to saved pipeline (.pkl file)

        Raises:
            FileNotFoundError: If pipeline file doesn't exist
            ValueError: If pipeline structure is invalid (no named steps, or
                a missing 'model' / 'preprocessor' step)
        """
        pipeline_path = Path(pipeline_path)
        if not pipeline_path.exists():
            raise FileNotFoundError(f"Pipeline not found: {pipeline_path}")

        # SECURITY NOTE: joblib.load unpickles the file, which can execute
        # arbitrary code. Only load pipeline artifacts from a trusted source.
        self.pipeline: Pipeline = joblib.load(pipeline_path)

        # Validate the structure up front so a wrong artifact fails with the
        # documented ValueError instead of an AttributeError later on.
        steps = getattr(self.pipeline, "named_steps", None)
        if steps is None:
            raise ValueError("Loaded object is not a pipeline with named steps")
        for required in ("model", "preprocessor"):
            if required not in steps:
                raise ValueError(f"Pipeline must contain '{required}' step")

        self.model = steps["model"]
        self.preprocessor = steps["preprocessor"]

        # TreeExplainer is optimized for tree-based models (XGBoost, RandomForest)
        self.explainer = shap.TreeExplainer(self.model)

        # Feature names after transformation, used to label SHAP output
        self.feature_names = self._get_feature_names()

    def _get_feature_names(self) -> List[str]:
        """
        Extract feature names from preprocessor.

        Returns:
            List of feature names after the preprocessor. Falls back to the
            hard-coded pipeline layout when the preprocessor raises
            AttributeError from ``get_feature_names_out``.
        """
        try:
            # sklearn 1.0+ method
            return list(self.preprocessor.get_feature_names_out())
        except AttributeError:
            return list(self._FALLBACK_FEATURE_NAMES)

    def _transform_data(self, X: pd.DataFrame) -> np.ndarray:
        """
        Transform raw transaction data through pipeline preprocessor.

        This is the crucial step mentioned in the notebook to resolve
        "You have categorical data..." errors.

        Args:
            X: Raw transaction DataFrame

        Returns:
            Output of the preprocessor's ``transform``, ready for SHAP
        """
        # Optional feature-extraction step runs before preprocessing
        if "features" in self.pipeline.named_steps:
            X = self.pipeline.named_steps["features"].transform(X)

        return self.preprocessor.transform(X)

    def calculate_shap_values(
        self, X: pd.DataFrame, transformed: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Calculate SHAP values for input data.

        Args:
            X: Transaction data (raw or transformed)
            transformed: If True, X is already transformed. If False, transform it.

        Returns:
            Tuple of (shap_values, transformed_X). ``shap_values`` is returned
            exactly as produced by the explainer.
        """
        X_transformed = X if transformed else self._transform_data(X)
        shap_values = self.explainer.shap_values(X_transformed)
        return shap_values, X_transformed

    def _positive_class_row(self, shap_values: Any, row: int = 0) -> np.ndarray:
        """
        Return the 1-D feature contributions for one row and the fraud class.

        Handles the three layouts handled here:
        - 2-D array ``(n_samples, n_features)``: single-output model
        - list of 2-D arrays: one array per class
        - 3-D array ``(n_samples, n_features, n_classes)``

        Args:
            shap_values: Raw output of the explainer's ``shap_values``
            row: Index of the sample to extract

        Returns:
            1-D array with one contribution per feature
        """
        if isinstance(shap_values, list):
            return np.asarray(shap_values[self._POSITIVE_CLASS])[row]

        values = np.asarray(shap_values)
        if values.ndim == 3:
            return values[row, :, self._POSITIVE_CLASS]
        return values[row]

    def _render_figure(
        self, figsize: Tuple[int, int], draw: Any, return_base64: bool
    ) -> Union[str, "matplotlib.figure.Figure"]:
        """
        Create a figure, run ``draw`` on it and return it in the requested form.

        The figure is closed if drawing fails, so a failed request does not
        leave a figure registered with pyplot.

        Args:
            figsize: Figure size in inches (width, height)
            draw: Zero-argument callable that plots onto the current figure
            return_base64: If True, return base64 PNG. If False, return Figure.

        Returns:
            Base64-encoded PNG string or matplotlib Figure
        """
        fig = plt.figure(figsize=figsize)
        try:
            draw()
            plt.tight_layout()
        except Exception:
            plt.close(fig)
            raise

        if return_base64:
            return self._plot_to_base64(fig)
        # The caller owns the figure and should close it when done.
        return fig

    def generate_waterfall(
        self, transaction: pd.DataFrame, return_base64: bool = True, max_display: int = 10
    ) -> Union[str, "matplotlib.figure.Figure"]:
        """
        Generate SHAP waterfall plot for a single transaction.

        Shows how each feature contributed to pushing the prediction
        from the base value (average) to the final prediction.

        Args:
            transaction: Single transaction DataFrame (1 row)
            return_base64: If True, return base64 PNG. If False, return Figure.
            max_display: Maximum features to display

        Returns:
            Base64-encoded PNG string or matplotlib Figure

        Raises:
            ValueError: If ``transaction`` does not contain exactly one row

        Example:
            >>> waterfall_img = explainer.generate_waterfall(transaction_df)
            >>> # Save to file
            >>> with open('waterfall.png', 'wb') as f:
            ...     f.write(base64.b64decode(waterfall_img))
        """
        if len(transaction) != 1:
            raise ValueError(f"Expected 1 transaction, got {len(transaction)}")

        X_transformed = self._transform_data(transaction)

        # Attach feature names so the plot is labelled with them
        X_df = pd.DataFrame(X_transformed, columns=self.feature_names)
        explanation = self.explainer(X_df)

        return self._render_figure(
            (10, 6),
            lambda: shap.plots.waterfall(
                explanation[0], max_display=max_display, show=False
            ),
            return_base64,
        )

    def generate_summary(
        self, X_sample: pd.DataFrame, return_base64: bool = True, max_display: int = 20
    ) -> Union[str, "matplotlib.figure.Figure"]:
        """
        Generate SHAP summary plot for global feature importance.

        Shows which features are most important across all predictions.
        Each dot represents a transaction, color indicates feature value.

        Args:
            X_sample: Sample of transactions (typically 100-1000 rows)
            return_base64: If True, return base64 PNG. If False, return Figure.
            max_display: Maximum features to display

        Returns:
            Base64-encoded PNG string or matplotlib Figure

        Example:
            >>> # Analyze 500 test transactions
            >>> summary_img = explainer.generate_summary(X_test[:500])
        """
        shap_values, X_transformed = self.calculate_shap_values(X_sample)

        return self._render_figure(
            (10, 8),
            lambda: shap.summary_plot(
                shap_values,
                X_transformed,
                feature_names=self.feature_names,
                max_display=max_display,
                show=False,
            ),
            return_base64,
        )

    def explain_prediction(
        self, transaction: pd.DataFrame, threshold: float = 0.5
    ) -> Dict[str, Any]:
        """
        Get comprehensive explanation for a single prediction.

        Only the first row of ``transaction`` is explained.

        Args:
            transaction: Single transaction DataFrame
            threshold: Decision threshold

        Returns:
            Dictionary with:
            - prediction: fraud probability
            - decision: "BLOCK" or "APPROVE"
            - threshold: the threshold that was applied
            - shap_values: feature contributions
            - top_features: top 5 features sorted by absolute impact
            - base_value: the explainer's expected value

        NOTE(review): for XGBoost classifiers, TreeExplainer contributions and
        ``base_value`` are usually in log-odds space rather than probability.
        Confirm before presenting them next to ``prediction``.

        Example:
            >>> explanation = explainer.explain_prediction(transaction_df, threshold=0.895)
            >>> print(explanation['decision'])  # "BLOCK"
            >>> print(explanation['top_features'])
            [{'feature': 'amt_log', 'impact': 0.32}, ...]
        """
        y_prob = float(self.pipeline.predict_proba(transaction)[0, self._POSITIVE_CLASS])

        shap_values, _ = self.calculate_shap_values(transaction)
        # Normalise per-class output to one contribution per feature
        contributions = [float(v) for v in self._positive_class_row(shap_values)]

        # expected_value is a scalar for single-output models and an array
        # with one entry per class otherwise
        expected = np.ravel(self.explainer.expected_value)
        base_value = float(
            expected[self._POSITIVE_CLASS] if expected.size > 1 else expected[0]
        )

        feature_impacts = sorted(
            (
                {"feature": feat, "impact": val, "abs_impact": abs(val)}
                for feat, val in zip(self.feature_names, contributions)
            ),
            key=lambda item: item["abs_impact"],
            reverse=True,
        )

        return {
            "prediction": y_prob,
            "decision": "BLOCK" if y_prob >= threshold else "APPROVE",
            "threshold": threshold,
            "shap_values": dict(zip(self.feature_names, contributions)),
            "top_features": feature_impacts[:5],
            "base_value": base_value,
        }

    def _plot_to_base64(self, fig: "matplotlib.figure.Figure") -> str:
        """
        Convert matplotlib figure to base64-encoded PNG.

        The figure is always closed, even if saving fails, so figures do not
        accumulate in pyplot's registry.

        Args:
            fig: Matplotlib figure

        Returns:
            Base64-encoded PNG string
        """
        buf = io.BytesIO()
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=100)
        finally:
            plt.close(fig)
        return base64.b64encode(buf.getvalue()).decode("utf-8")


# Public API: only the explainer class is exported by ``from <module> import *``.
__all__ = ["FraudExplainer"]