Spaces:
Sleeping
Sleeping
| """ | |
| PhishLens SHAP + LIME Explainability Module. | |
| Provides dual-method explainability for all PhishLens classifiers: | |
| - SHAP (SHapley Additive exPlanations) via TreeExplainer (tree models) or | |
| LinearExplainer (Logistic Regression) | |
| - LIME (Local Interpretable Model-agnostic Explanations) as independent check | |
| An "agreement score" quantifies how well SHAP and LIME agree on the top | |
| contributing features — high agreement increases analyst trust in the explanation. | |
| Security rationale: ML model explanations are a critical part of any security | |
| tool deployed in a real SOC environment. Analysts need to understand WHY an | |
| email was flagged to: | |
| 1. Verify the flag is correct (not a false positive) | |
| 2. Document the phishing indicators for incident reports | |
| 3. Feed intelligence back into detection rules | |
| 4. Avoid over-trusting black-box predictions | |
| Reference: | |
| - Lundberg, S.M. & Lee, S.I. (2017). A unified approach to interpreting model predictions. | |
| NeurIPS 2017. https://arxiv.org/abs/1705.07874 | |
| - Ribeiro, M.T., Singh, S., & Guestrin, C. (2016). "Why should I trust you?" | |
| ICML 2016. https://arxiv.org/abs/1602.04938 | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import shap | |
| from lime.lime_tabular import LimeTabularExplainer | |
| from sklearn.linear_model import LogisticRegression | |
| from src.utils.logger import get_logger | |
| log = get_logger(__name__) | |
class PhishExplainer:
    """Dual SHAP + LIME explainer for PhishLens classifiers.

    SHAP and LIME are initialised independently and on a best-effort basis:
    if either one fails to initialise, a warning is logged and that method is
    simply skipped when explanations are requested.

    Args:
        model: Fitted classifier (must have predict_proba).
        feature_names: List of feature names corresponding to X columns.
        X_train: Training data (used as LIME/Linear/Kernel background data).
        model_type: One of 'tree', 'linear', 'generic'. Auto-detected if None.
    """

    # Class-name fragments of models that shap.TreeExplainer can handle.
    _TREE_MODEL_HINTS = (
        "xgb",
        "lgbm",
        "catboost",
        "randomforest",
        "gradientboosting",
        "extratrees",
        "decisiontree",
    )

    # Operators LIME puts between a feature name and a bin threshold.
    _LIME_OPERATORS = (" <= ", " > ", " < ", " >= ", " = ")

    def __init__(
        self,
        model: Any,
        feature_names: List[str],
        X_train: np.ndarray,
        model_type: Optional[str] = None,
    ) -> None:
        self.model = model
        self.feature_names = feature_names
        # Non-finite values are zeroed so both explainers see clean background data.
        self.X_train = np.nan_to_num(
            np.asarray(X_train, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0
        )
        self.model_type = model_type or self._detect_model_type(model)
        self._shap_explainer = None
        self._lime_explainer = None
        self._init_explainers()

    def _detect_model_type(self, model: Any) -> str:
        """Detect model type for choosing the best SHAP explainer.

        The decision is based purely on the lower-cased class name of *model*.

        Returns:
            'tree', 'linear' or 'generic'.
        """
        model_class = type(model).__name__.lower()
        if any(hint in model_class for hint in self._TREE_MODEL_HINTS):
            return "tree"
        if "logistic" in model_class or "linear" in model_class:
            return "linear"
        return "generic"

    def _init_explainers(self) -> None:
        """Initialise SHAP and LIME explainers.

        Failures are logged as warnings and leave the corresponding explainer
        set to None, so the other method can still be used.
        """
        log.info(f"Initialising {self.model_type} SHAP explainer ...")
        try:
            if self.model_type == "tree":
                self._shap_explainer = shap.TreeExplainer(self.model)
            elif self.model_type == "linear":
                self._shap_explainer = shap.LinearExplainer(
                    self.model, self.X_train, feature_names=self.feature_names
                )
            else:
                # Fallback: KernelExplainer with 100-sample background (slow but universal)
                background = shap.sample(self.X_train, 100)
                self._shap_explainer = shap.KernelExplainer(
                    self.model.predict_proba, background
                )
            log.info("SHAP explainer ready.")
        except Exception as exc:
            log.warning(f"SHAP initialisation failed: {exc}")

        try:
            self._lime_explainer = LimeTabularExplainer(
                training_data=self.X_train,
                feature_names=self.feature_names,
                class_names=["Legitimate", "Phishing"],
                mode="classification",
                discretize_continuous=True,
                random_state=42,
            )
            log.info("LIME explainer ready.")
        except Exception as exc:
            log.warning(f"LIME initialisation failed: {exc}")

    @staticmethod
    def _phishing_class_values(shap_vals: Any) -> np.ndarray:
        """Normalise raw SHAP output to a [n_samples, n_features] array for class 1.

        Handles the layouts seen across SHAP versions and explainers:
        a per-class list of 2D arrays, a 3D array
        [n_samples, n_features, n_outputs], a plain 2D array, or a 1D vector.
        When only one output is present, that output is used.
        """
        if isinstance(shap_vals, (list, tuple)):
            per_class = shap_vals[1] if len(shap_vals) > 1 else shap_vals[0]
            arr = np.asarray(per_class)
        else:
            arr = np.asarray(shap_vals)
            if arr.ndim == 3:
                class_idx = 1 if arr.shape[2] > 1 else 0
                arr = arr[:, :, class_idx]
        if arr.ndim == 1:
            arr = arr.reshape(1, -1)
        return arr

    @classmethod
    def _lime_base_name(cls, condition: str) -> str:
        """Extract the bare feature name from a LIME condition string.

        Examples:
            "url_domain_length > 0.50"          -> "url_domain_length"
            "0.25 < url_domain_length <= 0.50"  -> "url_domain_length"
            "url_domain_length"                 -> "url_domain_length"
        """
        # Two-sided bins put the feature name in the middle, so they must be
        # recognised before the generic "text left of the operator" rule.
        _, lower_op, remainder = condition.partition(" < ")
        if lower_op and " <= " in remainder:
            return remainder.split(" <= ", 1)[0].strip()
        for op in cls._LIME_OPERATORS:
            if op in condition:
                return condition.split(op)[0].strip()
        return condition.strip()

    def _feature_name(self, col_idx: int) -> str:
        """Return the name of column *col_idx*, or a 'feat_<idx>' placeholder."""
        if col_idx < len(self.feature_names):
            return self.feature_names[col_idx]
        return f"feat_{col_idx}"

    def explain_single(
        self,
        x: np.ndarray,
        top_n: int = 15,
    ) -> Dict:
        """Generate SHAP + LIME explanations for a single email.

        Each method is attempted independently; a failure in one is logged as
        a warning and leaves its part of the result empty.

        Args:
            x: Single email feature vector shape [n_features]
                (a [1, n_features] row is flattened).
            top_n: Number of top features to return per method.

        Returns:
            Dict with:
            - shap_features: List of {feature, shap_value, value} dicts
            - lime_features: List of {feature, weight, value} dicts
            - agreement_score: Float 0-1 measuring SHAP/LIME agreement
              (always computed over the top 5 features of each method)
            - phishing_risk_features: Top features pushing toward phishing
        """
        x_clean = np.nan_to_num(
            np.asarray(x, dtype=np.float32).ravel(), nan=0.0, posinf=0.0, neginf=0.0
        )
        result: Dict = {
            "shap_features": [],
            "lime_features": [],
            "agreement_score": 0.0,
            "phishing_risk_features": [],
        }

        # --- SHAP -----------------------------------------------------------
        shap_top: List[Tuple[str, float]] = []
        if self._shap_explainer is not None:
            try:
                raw_vals = self._shap_explainer.shap_values(x_clean.reshape(1, -1))
                vals = self._phishing_class_values(raw_vals)[0]
                # Rank columns by absolute contribution, largest first.
                order = np.argsort(np.abs(vals))[::-1][:top_n]
                ranked: List[Tuple[str, float]] = []
                shap_features = []
                for col_idx in order:
                    col_idx = int(col_idx)
                    fname = self._feature_name(col_idx)
                    contribution = float(vals[col_idx])
                    ranked.append((fname, contribution))
                    shap_features.append({
                        "feature": fname,
                        "shap_value": contribution,
                        "value": float(x_clean[col_idx]),
                    })
                # Publish only once everything succeeded, so a mid-loop failure
                # never leaves a half-filled result behind.
                shap_top = ranked
                result["shap_features"] = shap_features
                # Phishing risk features: positive SHAP (push toward phishing)
                result["phishing_risk_features"] = [
                    {"feature": f, "shap_value": v}
                    for f, v in sorted(shap_top, key=lambda t: t[1], reverse=True)
                    if v > 0
                ][:top_n]
            except Exception as exc:
                log.warning(f"SHAP explanation failed: {exc}")

        # --- LIME -----------------------------------------------------------
        lime_top: List[Tuple[str, float]] = []
        if self._lime_explainer is not None:
            try:
                lime_exp = self._lime_explainer.explain_instance(
                    x_clean,
                    self.model.predict_proba,
                    num_features=top_n,
                    labels=(1,),  # Explain phishing class
                )
                lime_top = lime_exp.as_list(label=1)
                # name -> column index, to report the raw value of each feature
                name_to_col: Dict[str, int] = {}
                if self.feature_names:
                    name_to_col = {n: i for i, n in enumerate(self.feature_names)}
                lime_features = []
                for condition, weight in lime_top:
                    col_idx = name_to_col.get(self._lime_base_name(condition))
                    lime_features.append({
                        "feature": condition,
                        "weight": float(weight),
                        # 0.0 when the condition cannot be mapped back to a column
                        "value": float(x_clean[col_idx]) if col_idx is not None else 0.0,
                    })
                result["lime_features"] = lime_features
            except Exception as exc:
                log.warning(f"LIME explanation failed: {exc}")

        # --- Agreement score ------------------------------------------------
        result["agreement_score"] = self._compute_agreement(shap_top, lime_top, top_n=5)
        return result

    def _compute_agreement(
        self,
        shap_top: List[Tuple[str, float]],
        lime_top: List[Tuple[str, float]],
        top_n: int = 5,
    ) -> float:
        """Compute Jaccard similarity between top-N SHAP and LIME features.

        Agreement score = |SHAP_top ∩ LIME_top| / |SHAP_top ∪ LIME_top|

        LIME conditions are reduced to their bare feature name before the
        comparison, including two-sided bins such as "0.25 < f <= 0.50".

        Args:
            shap_top: SHAP top features (name, value) tuples.
            lime_top: LIME top features (condition, weight) tuples.
            top_n: Top-N features to compare.

        Returns:
            Float in [0, 1]. 1.0 = perfect agreement. 0.0 if either list is empty.
        """
        if not shap_top or not lime_top:
            return 0.0
        shap_names = {name for name, _ in shap_top[:top_n]}
        lime_names = {self._lime_base_name(cond) for cond, _ in lime_top[:top_n]}
        union = shap_names | lime_names
        if not union:
            return 0.0
        return float(len(shap_names & lime_names) / len(union))

    def batch_shap_values(self, X: np.ndarray) -> np.ndarray:
        """Compute SHAP values for a batch of emails.

        Used for population-level feature importance analysis and
        generating SHAP summary plots.

        Args:
            X: Feature matrix shape [n_samples, n_features].

        Returns:
            SHAP values array shape [n_samples, n_features] for the phishing
            class. An all-zero array is returned when no SHAP explainer is
            available or the computation fails (a warning is logged).
        """
        if self._shap_explainer is None:
            return np.zeros_like(X)
        X_clean = np.nan_to_num(
            np.asarray(X, dtype=np.float32), nan=0.0, posinf=0.0, neginf=0.0
        )
        try:
            return self._phishing_class_values(self._shap_explainer.shap_values(X_clean))
        except Exception as exc:
            log.warning(f"Batch SHAP failed: {exc}")
            return np.zeros_like(X_clean)