""" Explainable AI (XAI) Integration for COGNEXA Provides SHAP, LIME, and feature importance explanations for model predictions. This module offers: - SHAP TreeExplainer for tree-based models (XGBoost, LightGBM, Random Forest) - SHAP KernelExplainer for neural networks and other models - LIME for local explanations - Permutation-based feature importance - Summary plots and visualizations - Explanation caching and optimization Version: 1.0.0 """ import logging import pickle from pathlib import Path from typing import Dict, List, Optional, Tuple, Any from dataclasses import dataclass, asdict import json import numpy as np import pandas as pd try: import shap SHAP_AVAILABLE = True except ImportError: SHAP_AVAILABLE = False try: import lime import lime.lime_tabular LIME_AVAILABLE = True except ImportError: LIME_AVAILABLE = False from sklearn.inspection import permutation_importance logger = logging.getLogger(__name__) EXPLANATIONS_DIR = Path(__file__).parent / "explanations" EXPLANATIONS_DIR.mkdir(exist_ok=True) @dataclass class FeatureImportance: """Feature importance for a prediction""" feature_name: str importance_value: float direction: str # "positive" or "negative" impact_level: str # "high", "medium", "low" contribution_to_prediction: float @dataclass class SHAPExplanation: """SHAP-based explanation for a prediction""" prediction_id: str predicted_value: float actual_value: Optional[float] = None base_value: Optional[float] = None shap_values: Optional[Dict[str, float]] = None top_positive_features: List[FeatureImportance] = None top_negative_features: List[FeatureImportance] = None explanation_text: str = "" confidence_in_explanation: float = 0.0 @dataclass class LIMEExplanation: """LIME-based explanation for a prediction""" prediction_id: str predicted_value: float predicted_class: Optional[str] = None probability: float = 0.0 local_feature_importance: List[FeatureImportance] = None explanation_text: str = "" intercept: float = 0.0 class SHAPExplainer: """SHAP-based model explanation interface""" def __init__(self, model: Any, X_background: pd.DataFrame, feature_names: List[str]): """ Initialize SHAP explainer Args: model: Trained model X_background: Background dataset for SHAP (typically training data sample) feature_names: List of feature names """ if not SHAP_AVAILABLE: raise ImportError("SHAP library not installed. Install with: pip install shap") self.model = model self.X_background = X_background self.feature_names = feature_names self.explainer = None self.shap_values = None self._initialize_explainer() def _initialize_explainer(self): """Initialize appropriate SHAP explainer based on model type""" model_type = type(self.model).__name__ logger.info(f"Initializing SHAP explainer for {model_type}...") try: # For tree-based models (fastest) if hasattr(self.model, 'get_booster') or hasattr(self.model, 'feature_importances_'): logger.info("Using SHAP TreeExplainer (tree-based model)") self.explainer = shap.TreeExplainer(self.model) else: # Fallback to KernelExplainer (general-purpose, slower) logger.info("Using SHAP KernelExplainer (general-purpose)") self.explainer = shap.KernelExplainer( self.model.predict, shap.sample(self.X_background, 100) # Use sample for speed ) except Exception as e: logger.error(f"Failed to initialize SHAP explainer: {e}") raise def explain_instance( self, X_instance: pd.DataFrame, instance_index: int = 0 ) -> SHAPExplanation: """ Generate SHAP explanation for a single instance Args: X_instance: Instance to explain (single row) instance_index: Index of instance in dataframe Returns: SHAPExplanation object """ if isinstance(X_instance, pd.DataFrame): X_instance = X_instance.iloc[instance_index:instance_index + 1] else: X_instance = pd.DataFrame([X_instance], columns=self.feature_names) # Get prediction predicted_value = self.model.predict(X_instance)[0] # Get SHAP values shap_values = self.explainer.shap_values(X_instance) # Handle case where shap_values is array vs list if isinstance(shap_values, list): shap_values = shap_values[0] # For multi-class, take first class shap_values = shap_values[0] # Take first (only) instance # Create SHAP values dict shap_dict = dict(zip(self.feature_names, shap_values)) # Get base value base_value = float(self.explainer.expected_value) \ if hasattr(self.explainer, 'expected_value') \ else 0.0 # Sort by absolute value sorted_features = sorted( [(k, v) for k, v in shap_dict.items()], key=lambda x: abs(x[1]), reverse=True ) top_positive = [ FeatureImportance( feature_name=name, importance_value=value, direction="positive" if value > 0 else "negative", impact_level=self._get_impact_level(abs(value)), contribution_to_prediction=value ) for name, value in sorted_features[:5] if value > 0 ] top_negative = [ FeatureImportance( feature_name=name, importance_value=value, direction="negative", impact_level=self._get_impact_level(abs(value)), contribution_to_prediction=value ) for name, value in sorted_features[:5] if value < 0 ] explanation_text = self._generate_explanation_text( predicted_value, base_value, top_positive, top_negative ) return SHAPExplanation( prediction_id=f"shap_{np.random.randint(100000, 999999)}", predicted_value=predicted_value, base_value=base_value, shap_values=shap_dict, top_positive_features=top_positive, top_negative_features=top_negative, explanation_text=explanation_text, confidence_in_explanation=0.95 ) def explain_batch( self, X_batch: pd.DataFrame, num_samples: Optional[int] = None ) -> List[SHAPExplanation]: """Explain a batch of instances""" explanations = [] num_to_explain = num_samples or len(X_batch) for i in range(min(num_to_explain, len(X_batch))): try: exp = self.explain_instance(X_batch, i) explanations.append(exp) except Exception as e: logger.error(f"Failed to explain instance {i}: {e}") return explanations def feature_importance_summary(self) -> Dict[str, float]: """Get global feature importance (mean absolute SHAP values)""" if self.shap_values is None: self.shap_values = self.explainer.shap_values(self.X_background) if isinstance(self.shap_values, list): shap_array = self.shap_values[0] else: shap_array = self.shap_values mean_abs_shap = np.mean(np.abs(shap_array), axis=0) return dict(zip(self.feature_names, mean_abs_shap)) @staticmethod def _get_impact_level(value: float, thresholds: Tuple[float, float] = (0.05, 0.15)) -> str: """Classify impact level based on SHAP value magnitude""" if abs(value) > thresholds[1]: return "high" elif abs(value) > thresholds[0]: return "medium" else: return "low" @staticmethod def _generate_explanation_text( predicted_value: float, base_value: float, positive_features: List[FeatureImportance], negative_features: List[FeatureImportance] ) -> str: """Generate human-readable explanation""" lines = [] lines.append(f"Base prediction value: {base_value:.4f}") lines.append(f"Predicted value: {predicted_value:.4f}") lines.append(f"Change: +{predicted_value - base_value:.4f}") if positive_features: lines.append("\nTop positive contributors:") for feat in positive_features[:3]: lines.append(f" • {feat.feature_name}: +{feat.contribution_to_prediction:.4f}") if negative_features: lines.append("\nTop negative contributors:") for feat in negative_features[:3]: lines.append(f" • {feat.feature_name}: {feat.contribution_to_prediction:.4f}") return "\n".join(lines) class LIMEExplainer: """LIME-based model explanation interface""" def __init__( self, model: Any, X_training: pd.DataFrame, feature_names: List[str], class_names: Optional[List[str]] = None, mode: str = "classification" ): """ Initialize LIME explainer Args: model: Trained model (should have predict or predict_proba method) X_training: Training data (for LIME reference) feature_names: List of feature names class_names: List of class names (for classification) mode: "classification" or "regression" """ if not LIME_AVAILABLE: raise ImportError("LIME library not installed. Install with: pip install lime") self.model = model self.X_training = X_training self.feature_names = feature_names self.class_names = class_names or ["Negative", "Positive"] self.mode = mode if mode == "classification": self.explainer = lime.lime_tabular.LimeTabularExplainer( X_training.values, feature_names=feature_names, class_names=self.class_names, mode="classification", random_state=42 ) else: self.explainer = lime.lime_tabular.LimeTabularExplainer( X_training.values, feature_names=feature_names, mode="regression", random_state=42 ) def explain_instance( self, X_instance: pd.DataFrame, instance_index: int = 0, num_features: int = 10 ) -> LIMEExplanation: """ Generate LIME explanation for a single instance Args: X_instance: Instance to explain instance_index: Index of instance num_features: Number of features to include in explanation Returns: LIMEExplanation object """ if isinstance(X_instance, pd.DataFrame): X_array = X_instance.iloc[instance_index].values else: X_array = X_instance # Get prediction if self.mode == "classification": predicted_proba = self.model.predict_proba([X_array])[0] predicted_class = self.model.predict([X_array])[0] predicted_value = predicted_proba[predicted_class] else: predicted_value = self.model.predict([X_array])[0] predicted_class = None # Get LIME explanation exp = self.explainer.explain_instance( X_array, self.model.predict_proba if self.mode == "classification" else self.model.predict, num_features=num_features ) # Extract feature contributions contributions = [] for feature_idx, weight in exp.as_list(): # Parse feature description if " <= " in feature_idx or " > " in feature_idx: feature_name = feature_idx.split(" ")[0] else: feature_name = feature_idx contributions.append(FeatureImportance( feature_name=feature_name, importance_value=weight, direction="positive" if weight > 0 else "negative", impact_level=self._get_impact_level(abs(weight)), contribution_to_prediction=weight )) explanation_text = self._generate_explanation_text( predicted_value, contributions, predicted_class ) return LIMEExplanation( prediction_id=f"lime_{np.random.randint(100000, 999999)}", predicted_value=predicted_value, predicted_class=str(predicted_class) if predicted_class is not None else None, probability=predicted_value, local_feature_importance=contributions, explanation_text=explanation_text, intercept=exp.intercept[predicted_class] if self.mode == "classification" else exp.intercept ) @staticmethod def _get_impact_level(value: float, thresholds: Tuple[float, float] = (0.05, 0.15)) -> str: """Classify impact level""" if abs(value) > thresholds[1]: return "high" elif abs(value) > thresholds[0]: return "medium" else: return "low" @staticmethod def _generate_explanation_text( predicted_value: float, contributions: List[FeatureImportance], predicted_class: Optional[str] = None ) -> str: """Generate human-readable explanation""" lines = [] if predicted_class: lines.append(f"Predicted class: {predicted_class}") lines.append(f"Prediction score: {predicted_value:.4f}") lines.append("\nTop contributing features:") for contrib in contributions[:5]: sign = "+" if contrib.importance_value > 0 else "" lines.append(f" • {contrib.feature_name}: {sign}{contrib.importance_value:.4f}") return "\n".join(lines) class PermutationImportanceExplainer: """Permutation-based feature importance""" def __init__(self, model: Any, X_test: pd.DataFrame, y_test: pd.Series, feature_names: List[str]): """Initialize permutation importance explainer""" self.model = model self.X_test = X_test self.y_test = y_test self.feature_names = feature_names self.importance = None def compute_importance( self, n_repeats: int = 10, random_state: int = 42 ) -> Dict[str, float]: """Compute permutation importance""" logger.info(f"Computing permutation importance ({n_repeats} repeats)...") result = permutation_importance( self.model, self.X_test, self.y_test, n_repeats=n_repeats, random_state=random_state, n_jobs=-1 ) self.importance = dict(zip(self.feature_names, result.importances_mean)) logger.info("Permutation importance computed successfully") return self.importance def get_top_features(self, n: int = 10) -> List[Tuple[str, float]]: """Get top N features by importance""" if self.importance is None: raise ValueError("Call compute_importance() first") sorted_features = sorted( self.importance.items(), key=lambda x: x[1], reverse=True ) return sorted_features[:n] class ExplanationCache: """Cache for explanations to avoid redundant computation""" def __init__(self, cache_dir: Path = EXPLANATIONS_DIR): self.cache_dir = cache_dir self.cache_dir.mkdir(exist_ok=True) self.memory_cache = {} def get(self, key: str) -> Optional[Dict]: """Get explanation from cache""" # Try memory cache first if key in self.memory_cache: return self.memory_cache[key] # Try disk cache cache_file = self.cache_dir / f"{key}.json" if cache_file.exists(): with open(cache_file, "r") as f: data = json.load(f) self.memory_cache[key] = data return data return None def set(self, key: str, value: Dict): """Store explanation in cache""" # Store in memory self.memory_cache[key] = value # Store on disk cache_file = self.cache_dir / f"{key}.json" with open(cache_file, "w") as f: json.dump(value, f, default=str) def clear(self): """Clear cache""" self.memory_cache.clear() if __name__ == "__main__": logging.basicConfig(level=logging.INFO) logger.info("="*80) logger.info("COGNEXA Explainable AI (XAI) Module") logger.info("="*80) # This module is imported by other components logger.info("XAI module loaded successfully") logger.info(f"SHAP available: {SHAP_AVAILABLE}") logger.info(f"LIME available: {LIME_AVAILABLE}")