Spaces:
Sleeping
Sleeping
| """ | |
| Explainable AI (XAI) Integration for COGNEXA | |
| Provides SHAP, LIME, and feature importance explanations for model predictions. | |
| This module offers: | |
| - SHAP TreeExplainer for tree-based models (XGBoost, LightGBM, Random Forest) | |
| - SHAP KernelExplainer for neural networks and other models | |
| - LIME for local explanations | |
| - Permutation-based feature importance | |
| - Summary plots and visualizations | |
| - Explanation caching and optimization | |
| Version: 1.0.0 | |
| """ | |
| import logging | |
| import pickle | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple, Any | |
| from dataclasses import dataclass, asdict | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| try: | |
| import shap | |
| SHAP_AVAILABLE = True | |
| except ImportError: | |
| SHAP_AVAILABLE = False | |
| try: | |
| import lime | |
| import lime.lime_tabular | |
| LIME_AVAILABLE = True | |
| except ImportError: | |
| LIME_AVAILABLE = False | |
| from sklearn.inspection import permutation_importance | |
logger = logging.getLogger(__name__)

# Default on-disk location used by ExplanationCache; created eagerly when the
# module is imported so later writes never hit a missing directory.
EXPLANATIONS_DIR = Path(__file__).parent.joinpath("explanations")
EXPLANATIONS_DIR.mkdir(exist_ok=True)
@dataclass
class FeatureImportance:
    """Attribution of a single feature to one prediction.

    The SHAP and LIME explainers in this module build instances with keyword
    arguments, so the ``@dataclass`` decorator (which generates ``__init__``)
    is required.
    """

    feature_name: str
    importance_value: float  # signed attribution weight
    direction: str  # "positive" or "negative"
    impact_level: str  # "high", "medium", "low"
    contribution_to_prediction: float
@dataclass
class SHAPExplanation:
    """SHAP-based explanation for a single prediction.

    Instances are returned by ``SHAPExplainer.explain_instance``. The list
    fields default to ``None`` (not an empty list) when not supplied.
    """

    prediction_id: str
    predicted_value: float
    actual_value: Optional[float] = None
    base_value: Optional[float] = None  # explainer's expected value
    shap_values: Optional[Dict[str, float]] = None  # feature name -> SHAP value
    top_positive_features: Optional[List["FeatureImportance"]] = None
    top_negative_features: Optional[List["FeatureImportance"]] = None
    explanation_text: str = ""
    confidence_in_explanation: float = 0.0
@dataclass
class LIMEExplanation:
    """LIME-based explanation for a single prediction.

    Instances are returned by ``LIMEExplainer.explain_instance``.
    ``local_feature_importance`` defaults to ``None`` when not supplied.
    """

    prediction_id: str
    predicted_value: float
    predicted_class: Optional[str] = None
    probability: float = 0.0
    local_feature_importance: Optional[List["FeatureImportance"]] = None
    explanation_text: str = ""
    intercept: float = 0.0  # intercept of LIME's local surrogate model
class SHAPExplainer:
    """SHAP-based model explanation interface.

    Uses ``shap.TreeExplainer`` for models that look tree-based (they expose
    ``get_booster`` or ``feature_importances_``) and falls back to the slower,
    model-agnostic ``shap.KernelExplainer`` otherwise.
    """

    def __init__(
        self,
        model: Any,
        X_background: pd.DataFrame,
        feature_names: List[str],
        class_index: int = 0,
    ):
        """
        Initialize SHAP explainer

        Args:
            model: Trained model exposing ``predict``
            X_background: Background dataset for SHAP (typically a training data sample)
            feature_names: List of feature names, in column order
            class_index: For multi-output SHAP results (e.g. per-class values of a
                classifier), which output's SHAP values and base value to report.
                Defaults to 0, the first output.

        Raises:
            ImportError: If the ``shap`` package is not installed.
        """
        if not SHAP_AVAILABLE:
            raise ImportError("SHAP library not installed. Install with: pip install shap")
        self.model = model
        self.X_background = X_background
        self.feature_names = feature_names
        self.class_index = class_index
        self.explainer = None
        # Raw SHAP values for X_background, computed lazily by feature_importance_summary().
        self.shap_values = None
        self._initialize_explainer()

    def _initialize_explainer(self):
        """Initialize appropriate SHAP explainer based on model type"""
        model_type = type(self.model).__name__
        logger.info(f"Initializing SHAP explainer for {model_type}...")
        try:
            if hasattr(self.model, 'get_booster') or hasattr(self.model, 'feature_importances_'):
                # Tree-based models (fastest)
                logger.info("Using SHAP TreeExplainer (tree-based model)")
                self.explainer = shap.TreeExplainer(self.model)
            else:
                # Fallback to KernelExplainer (general-purpose, slower)
                logger.info("Using SHAP KernelExplainer (general-purpose)")
                self.explainer = shap.KernelExplainer(
                    self.model.predict,
                    shap.sample(self.X_background, 100)  # Use sample for speed
                )
        except Exception as e:
            logger.error(f"Failed to initialize SHAP explainer: {e}")
            raise

    @staticmethod
    def _select_output(shap_values: Any, class_index: int = 0) -> np.ndarray:
        """Reduce raw ``shap_values`` output to one (n_samples, n_features) array.

        SHAP returns either a single 2-D array, a list with one array per model
        output (older multi-output API), or a 3-D array shaped
        (n_samples, n_features, n_outputs) (newer SHAP versions). For
        multi-output results the output at ``class_index`` is selected.
        """
        if isinstance(shap_values, list):
            return np.asarray(shap_values[class_index])
        values = np.asarray(shap_values)
        if values.ndim == 3:
            return values[:, :, class_index]
        return values

    @staticmethod
    def _select_base_value(expected_value: Any, class_index: int = 0) -> float:
        """Return the explainer's base value as a float.

        A scalar (or single-element array) is returned as-is; a per-output
        array is indexed with ``class_index`` so it matches ``_select_output``.
        """
        values = np.asarray(expected_value, dtype=float).ravel()
        return float(values[0] if values.size == 1 else values[class_index])

    def explain_instance(
        self,
        X_instance: pd.DataFrame,
        instance_index: int = 0
    ) -> "SHAPExplanation":
        """
        Generate SHAP explanation for a single instance

        Args:
            X_instance: DataFrame containing the instance, or a single row of
                feature values (wrapped into a one-row DataFrame using
                ``feature_names``)
            instance_index: Row position (``iloc``) of the instance when
                ``X_instance`` is a DataFrame; ignored otherwise

        Returns:
            SHAPExplanation with per-feature SHAP values, up to 5 of the largest
            positive and up to 5 of the largest negative contributors, and a
            human-readable summary.
        """
        import secrets  # local: avoids touching numpy's global RNG for IDs

        if isinstance(X_instance, pd.DataFrame):
            X_instance = X_instance.iloc[instance_index:instance_index + 1]
        else:
            X_instance = pd.DataFrame([X_instance], columns=self.feature_names)

        predicted_value = self.model.predict(X_instance)[0]

        raw_shap = self.explainer.shap_values(X_instance)
        row = self._select_output(raw_shap, self.class_index)[0]  # first (only) instance
        shap_dict = {name: float(value) for name, value in zip(self.feature_names, row)}

        base_value = (
            self._select_base_value(self.explainer.expected_value, self.class_index)
            if hasattr(self.explainer, 'expected_value')
            else 0.0
        )

        # Rank each sign separately so negatives are still reported when the
        # largest-magnitude contributions are all positive (and vice versa).
        ranked = sorted(shap_dict.items(), key=lambda kv: kv[1], reverse=True)
        positives = [(name, value) for name, value in ranked if value > 0][:5]
        negatives = [(name, value) for name, value in reversed(ranked) if value < 0][:5]

        top_positive = [
            FeatureImportance(
                feature_name=name,
                importance_value=value,
                direction="positive",
                impact_level=self._get_impact_level(value),
                contribution_to_prediction=value
            )
            for name, value in positives
        ]
        top_negative = [
            FeatureImportance(
                feature_name=name,
                importance_value=value,
                direction="negative",
                impact_level=self._get_impact_level(value),
                contribution_to_prediction=value
            )
            for name, value in negatives
        ]

        explanation_text = self._generate_explanation_text(
            predicted_value, base_value, top_positive, top_negative
        )

        return SHAPExplanation(
            prediction_id=f"shap_{100000 + secrets.randbelow(900000)}",
            predicted_value=predicted_value,
            base_value=base_value,
            shap_values=shap_dict,
            top_positive_features=top_positive,
            top_negative_features=top_negative,
            explanation_text=explanation_text,
            confidence_in_explanation=0.95
        )

    def explain_batch(
        self,
        X_batch: pd.DataFrame,
        num_samples: Optional[int] = None
    ) -> List["SHAPExplanation"]:
        """Explain the first ``num_samples`` rows of ``X_batch`` (all rows if None).

        Rows that fail are logged and skipped, so the result may be shorter
        than requested.
        """
        if num_samples is None:
            count = len(X_batch)
        else:
            count = min(num_samples, len(X_batch))
        explanations = []
        for i in range(count):
            try:
                explanations.append(self.explain_instance(X_batch, i))
            except Exception as e:
                # Best-effort: one bad row should not abort the whole batch.
                logger.error(f"Failed to explain instance {i}: {e}")
        return explanations

    def feature_importance_summary(self) -> Dict[str, float]:
        """Get global feature importance (mean absolute SHAP values over X_background)."""
        if self.shap_values is None:
            self.shap_values = self.explainer.shap_values(self.X_background)
        shap_array = self._select_output(self.shap_values, self.class_index)
        mean_abs_shap = np.abs(shap_array).mean(axis=0)
        return {name: float(value) for name, value in zip(self.feature_names, mean_abs_shap)}

    @staticmethod
    def _get_impact_level(value: float, thresholds: Tuple[float, float] = (0.05, 0.15)) -> str:
        """Classify impact level ("high"/"medium"/"low") from |value| and thresholds."""
        if abs(value) > thresholds[1]:
            return "high"
        elif abs(value) > thresholds[0]:
            return "medium"
        else:
            return "low"

    @staticmethod
    def _generate_explanation_text(
        predicted_value: float,
        base_value: float,
        positive_features: List["FeatureImportance"],
        negative_features: List["FeatureImportance"]
    ) -> str:
        """Generate human-readable explanation (top 3 contributors of each sign)."""
        lines = [
            f"Base prediction value: {base_value:.4f}",
            f"Predicted value: {predicted_value:.4f}",
            # "+" format flag prints the correct sign for increases and decreases.
            f"Change: {predicted_value - base_value:+.4f}",
        ]
        if positive_features:
            lines.append("\nTop positive contributors:")
            for feat in positive_features[:3]:
                lines.append(f" • {feat.feature_name}: +{feat.contribution_to_prediction:.4f}")
        if negative_features:
            lines.append("\nTop negative contributors:")
            for feat in negative_features[:3]:
                lines.append(f" • {feat.feature_name}: {feat.contribution_to_prediction:.4f}")
        return "\n".join(lines)
class LIMEExplainer:
    """LIME-based model explanation interface"""

    def __init__(
        self,
        model: Any,
        X_training: pd.DataFrame,
        feature_names: List[str],
        class_names: Optional[List[str]] = None,
        mode: str = "classification"
    ):
        """
        Initialize LIME explainer

        Args:
            model: Trained model (needs ``predict``; also ``predict_proba`` for classification)
            X_training: Training data (LIME reference distribution)
            feature_names: List of feature names, in column order
            class_names: List of class names (for classification); defaults to
                ["Negative", "Positive"]
            mode: "classification"; any other value is treated as regression

        Raises:
            ImportError: If the ``lime`` package is not installed.
        """
        if not LIME_AVAILABLE:
            raise ImportError("LIME library not installed. Install with: pip install lime")
        self.model = model
        self.X_training = X_training
        self.feature_names = feature_names
        self.class_names = class_names or ["Negative", "Positive"]
        self.mode = mode
        if mode == "classification":
            self.explainer = lime.lime_tabular.LimeTabularExplainer(
                X_training.values,
                feature_names=feature_names,
                class_names=self.class_names,
                mode="classification",
                random_state=42
            )
        else:
            self.explainer = lime.lime_tabular.LimeTabularExplainer(
                X_training.values,
                feature_names=feature_names,
                mode="regression",
                random_state=42
            )

    @staticmethod
    def _resolve_class_index(classes: Any, predicted_class: Any, predicted_proba: Any) -> int:
        """Return the column of ``predict_proba`` that corresponds to ``predicted_class``.

        Uses the model's ``classes_`` ordering when available (so string or
        non-contiguous labels work); otherwise falls back to the argmax column.
        """
        if classes is not None:
            try:
                return list(classes).index(predicted_class)
            except ValueError:
                pass
        return int(np.argmax(predicted_proba))

    def explain_instance(
        self,
        X_instance: pd.DataFrame,
        instance_index: int = 0,
        num_features: int = 10
    ) -> "LIMEExplanation":
        """
        Generate LIME explanation for a single instance

        Args:
            X_instance: DataFrame containing the instance, or a single row of feature values
            instance_index: Row position (``iloc``) when ``X_instance`` is a DataFrame
            num_features: Number of features to include in explanation

        Returns:
            LIMEExplanation object
        """
        import secrets  # local: avoids touching numpy's global RNG for IDs

        if isinstance(X_instance, pd.DataFrame):
            X_array = X_instance.iloc[instance_index].values
        else:
            X_array = np.asarray(X_instance)

        if self.mode == "classification":
            predicted_proba = self.model.predict_proba([X_array])[0]
            predicted_class = self.model.predict([X_array])[0]
            label = self._resolve_class_index(
                getattr(self.model, "classes_", None), predicted_class, predicted_proba
            )
            predicted_value = predicted_proba[label]
            # LIME explains only label 1 unless told otherwise; request the
            # predicted class so as_map() and intercept contain it.
            exp = self.explainer.explain_instance(
                X_array,
                self.model.predict_proba,
                labels=(label,),
                num_features=num_features
            )
        else:
            predicted_value = self.model.predict([X_array])[0]
            predicted_class = None
            exp = self.explainer.explain_instance(
                X_array,
                self.model.predict,
                num_features=num_features
            )
            # LIME stores regression explanations under its dummy label.
            label = getattr(exp, "dummy_label", 1)

        # as_map() yields (feature_index, weight) pairs, so names come straight
        # from feature_names instead of being parsed out of descriptions such
        # as "3.00 < age <= 5.00".
        contributions = [
            FeatureImportance(
                feature_name=self.feature_names[feature_id],
                importance_value=weight,
                direction="positive" if weight > 0 else "negative",
                impact_level=self._get_impact_level(weight),
                contribution_to_prediction=weight
            )
            for feature_id, weight in exp.as_map()[label]
        ]

        explanation_text = self._generate_explanation_text(
            predicted_value, contributions, predicted_class
        )

        return LIMEExplanation(
            prediction_id=f"lime_{100000 + secrets.randbelow(900000)}",
            predicted_value=predicted_value,
            predicted_class=str(predicted_class) if predicted_class is not None else None,
            probability=predicted_value,
            local_feature_importance=contributions,
            explanation_text=explanation_text,
            intercept=float(exp.intercept[label])
        )

    @staticmethod
    def _get_impact_level(value: float, thresholds: Tuple[float, float] = (0.05, 0.15)) -> str:
        """Classify impact level ("high"/"medium"/"low") from |value| and thresholds."""
        if abs(value) > thresholds[1]:
            return "high"
        elif abs(value) > thresholds[0]:
            return "medium"
        else:
            return "low"

    @staticmethod
    def _generate_explanation_text(
        predicted_value: float,
        contributions: List["FeatureImportance"],
        predicted_class: Optional[str] = None
    ) -> str:
        """Generate human-readable explanation (top 5 contributions)."""
        lines = []
        # Compare against None so a legitimate class label 0 is still shown.
        if predicted_class is not None:
            lines.append(f"Predicted class: {predicted_class}")
        lines.append(f"Prediction score: {predicted_value:.4f}")
        lines.append("\nTop contributing features:")
        for contrib in contributions[:5]:
            sign = "+" if contrib.importance_value > 0 else ""
            lines.append(f" • {contrib.feature_name}: {sign}{contrib.importance_value:.4f}")
        return "\n".join(lines)
class PermutationImportanceExplainer:
    """Global feature importance via ``sklearn.inspection.permutation_importance``.

    ``compute_importance`` stores the mean importance per feature name in
    ``self.importance``; ``get_top_features`` ranks that mapping.
    """

    def __init__(self, model: Any, X_test: pd.DataFrame, y_test: pd.Series, feature_names: List[str]):
        """Initialize permutation importance explainer.

        Args:
            model: Fitted estimator handed to ``permutation_importance``.
            X_test: Evaluation features.
            y_test: Evaluation targets.
            feature_names: Names paired positionally with the importances.
        """
        self.model, self.X_test, self.y_test = model, X_test, y_test
        self.feature_names = feature_names
        self.importance = None  # populated by compute_importance()

    def compute_importance(
        self,
        n_repeats: int = 10,
        random_state: int = 42
    ) -> Dict[str, float]:
        """Run permutation importance (all CPU cores) and cache the per-feature means."""
        logger.info(f"Computing permutation importance ({n_repeats} repeats)...")
        outcome = permutation_importance(
            self.model,
            self.X_test,
            self.y_test,
            n_jobs=-1,
            random_state=random_state,
            n_repeats=n_repeats,
        )
        self.importance = {
            name: mean_score
            for name, mean_score in zip(self.feature_names, outcome.importances_mean)
        }
        logger.info("Permutation importance computed successfully")
        return self.importance

    def get_top_features(self, n: int = 10) -> List[Tuple[str, float]]:
        """Return the ``n`` highest-importance (name, score) pairs, best first.

        Raises:
            ValueError: If ``compute_importance`` has not been called yet.
        """
        scores = self.importance
        if scores is None:
            raise ValueError("Call compute_importance() first")
        return sorted(scores.items(), key=lambda pair: pair[1], reverse=True)[:n]
class ExplanationCache:
    """Two-level (in-memory dict + JSON files on disk) cache for explanations.

    Each entry is kept in ``memory_cache`` and written to
    ``<cache_dir>/<key>.json``; ``get`` falls back to disk on a memory miss.
    Values are written with ``json.dump(..., default=str)``, so values that
    are not JSON-serializable come back from disk as strings.
    """

    def __init__(self, cache_dir: Optional[Path] = None):
        """
        Args:
            cache_dir: Directory for the JSON files (``Path`` or ``str``).
                ``None`` means the module-level ``EXPLANATIONS_DIR``. The
                directory (and any missing parents) is created if needed.
        """
        # Resolve the default at call time so this class does not bind the
        # module constant when it is defined.
        self.cache_dir = EXPLANATIONS_DIR if cache_dir is None else Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.memory_cache: Dict[str, Any] = {}

    def _path_for(self, key: str) -> Path:
        """Return the JSON file that backs ``key``."""
        return self.cache_dir / f"{key}.json"

    def get(self, key: str) -> Optional[Dict]:
        """Get explanation from cache (memory first, then disk); ``None`` on a miss."""
        if key in self.memory_cache:
            return self.memory_cache[key]
        cache_file = self._path_for(key)
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return None
        except json.JSONDecodeError as e:
            # A truncated/corrupt entry is treated as a cache miss, not a crash.
            logger.warning(f"Ignoring unreadable cache file {cache_file}: {e}")
            return None
        self.memory_cache[key] = data
        return data

    def set(self, key: str, value: Dict):
        """Store explanation in memory and on disk.

        The JSON is written to a temporary sibling file and then renamed over
        the target, so a crash mid-write cannot leave a half-written entry.
        """
        self.memory_cache[key] = value
        cache_file = self._path_for(key)
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(value, f, default=str)
            tmp_file.replace(cache_file)
        except Exception:
            if tmp_file.exists():
                tmp_file.unlink()
            raise

    def clear(self, *, include_disk: bool = False):
        """Clear cache.

        By default only the in-memory layer is dropped, so entries already on
        disk are still returned (and reloaded) by ``get``. Pass
        ``include_disk=True`` to also delete every ``*.json`` file in
        ``cache_dir``.
        """
        self.memory_cache.clear()
        if include_disk:
            for cache_file in self.cache_dir.glob("*.json"):
                cache_file.unlink()
if __name__ == "__main__":
    # Running the file directly only reports which optional backends were
    # importable; the explainers are used by importing this module.
    logging.basicConfig(level=logging.INFO)
    rule = "=" * 80
    for banner_line in (rule, "COGNEXA Explainable AI (XAI) Module", rule):
        logger.info(banner_line)
    logger.info("XAI module loaded successfully")
    logger.info(f"SHAP available: {SHAP_AVAILABLE}")
    logger.info(f"LIME available: {LIME_AVAILABLE}")