# Source: SPG_ML / explainable_ai.py
# Author: meetmendapara
# Commit: 5059de5 ("Added Personalization Models")
"""
Explainable AI (XAI) Integration for COGNEXA
Provides SHAP, LIME, and feature importance explanations for model predictions.
This module offers:
- SHAP TreeExplainer for tree-based models (XGBoost, LightGBM, Random Forest)
- SHAP KernelExplainer for neural networks and other models
- LIME for local explanations
- Permutation-based feature importance
- Summary plots and visualizations
- Explanation caching and optimization
Version: 1.0.0
"""
import logging
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
from dataclasses import dataclass, asdict
import json
import numpy as np
import pandas as pd
# Optional explanation backends: a missing package is recorded in a flag
# instead of failing at import time. The explainer classes check these flags
# in their constructors and raise ImportError there.
try:
    import shap
    SHAP_AVAILABLE = True
except ImportError:
    SHAP_AVAILABLE = False
try:
    import lime
    import lime.lime_tabular
    LIME_AVAILABLE = True
except ImportError:
    LIME_AVAILABLE = False
from sklearn.inspection import permutation_importance
logger = logging.getLogger(__name__)
# Default directory for cached explanations, located next to this module.
EXPLANATIONS_DIR = Path(__file__).parent / "explanations"
# NOTE: the directory is created as a side effect of importing this module.
EXPLANATIONS_DIR.mkdir(exist_ok=True)
@dataclass
class FeatureImportance:
    """Signed contribution of one feature to one prediction.

    The SHAP and LIME explainers in this module build one record per feature.
    """
    feature_name: str  # name of the feature the values below refer to
    importance_value: float  # signed SHAP value or LIME weight
    direction: str  # "positive" or "negative"
    impact_level: str  # "high", "medium", "low"
    # The explainers in this module store the same signed value here as in
    # importance_value.
    contribution_to_prediction: float
@dataclass
class SHAPExplanation:
    """SHAP-based explanation for a single prediction."""
    prediction_id: str  # SHAPExplainer generates ids of the form "shap_<number>"
    predicted_value: float  # model output for the explained instance
    actual_value: Optional[float] = None  # not filled in by SHAPExplainer.explain_instance
    base_value: Optional[float] = None  # explainer's expected value (0.0 when unavailable)
    shap_values: Optional[Dict[str, float]] = None  # feature name -> SHAP value
    # The two lists default to None (not an empty list), hence Optional.
    top_positive_features: Optional[List[FeatureImportance]] = None
    top_negative_features: Optional[List[FeatureImportance]] = None
    explanation_text: str = ""  # human-readable summary
    confidence_in_explanation: float = 0.0  # SHAPExplainer.explain_instance sets a fixed 0.95
@dataclass
class LIMEExplanation:
    """LIME-based explanation for a single prediction."""
    prediction_id: str  # LIMEExplainer generates ids of the form "lime_<number>"
    predicted_value: float  # class probability (classification) or model output (regression)
    predicted_class: Optional[str] = None  # stringified predicted label; None for regression
    probability: float = 0.0  # LIMEExplainer stores the same number as predicted_value here
    # Defaults to None (not an empty list), hence Optional.
    local_feature_importance: Optional[List[FeatureImportance]] = None
    explanation_text: str = ""  # human-readable summary
    intercept: float = 0.0  # taken from the LIME explanation's `intercept`
class SHAPExplainer:
    """SHAP-based model explanation interface.

    Wraps a ``shap`` explainer chosen from the model's attributes and converts
    raw SHAP output into ``SHAPExplanation`` records. For models with several
    outputs/classes only the first output is reported, both for the SHAP values
    and for the base (expected) value.
    """

    def __init__(self, model: Any, X_background: pd.DataFrame, feature_names: List[str]):
        """
        Initialize SHAP explainer

        Args:
            model: Trained model
            X_background: Background dataset for SHAP (typically training data sample)
            feature_names: List of feature names

        Raises:
            ImportError: If the ``shap`` package is not installed.
        """
        if not SHAP_AVAILABLE:
            raise ImportError("SHAP library not installed. Install with: pip install shap")
        self.model = model
        self.X_background = X_background
        self.feature_names = feature_names
        self.explainer = None
        # Raw SHAP values for X_background; computed lazily by
        # feature_importance_summary().
        self.shap_values = None
        self._initialize_explainer()

    def _initialize_explainer(self):
        """Initialize appropriate SHAP explainer based on model type"""
        model_type = type(self.model).__name__
        logger.info(f"Initializing SHAP explainer for {model_type}...")
        try:
            # Heuristic: XGBoost-style boosters and estimators exposing
            # feature_importances_ are treated as tree models (fastest path).
            if hasattr(self.model, 'get_booster') or hasattr(self.model, 'feature_importances_'):
                logger.info("Using SHAP TreeExplainer (tree-based model)")
                self.explainer = shap.TreeExplainer(self.model)
            else:
                # Fallback to KernelExplainer (general-purpose, slower)
                logger.info("Using SHAP KernelExplainer (general-purpose)")
                self.explainer = shap.KernelExplainer(
                    self.model.predict,
                    shap.sample(self.X_background, 100)  # Use sample for speed
                )
        except Exception as e:
            logger.error(f"Failed to initialize SHAP explainer: {e}")
            raise

    @staticmethod
    def _first_output(values: Any) -> np.ndarray:
        """Reduce raw SHAP output to a 2-D (instances, features) array for the first output."""
        if isinstance(values, list):
            # One array per class/output: keep the first.
            values = values[0]
        values = np.asarray(values)
        if values.ndim == 3:
            # Outputs stacked on the last axis (layout used by newer shap
            # releases for multi-output models): keep the first.
            values = values[:, :, 0]
        return values

    def _base_value(self) -> float:
        """Return the explainer's expected value for the first output, or 0.0 if absent."""
        expected = getattr(self.explainer, 'expected_value', None)
        if expected is None:
            return 0.0
        # expected_value may be a scalar or an array with one entry per output;
        # float() on a multi-element array would raise.
        flat = np.ravel(expected)
        return float(flat[0]) if flat.size else 0.0

    @classmethod
    def _as_importance(cls, name: str, value: float) -> "FeatureImportance":
        """Build a FeatureImportance record from a feature name and its signed SHAP value."""
        return FeatureImportance(
            feature_name=name,
            importance_value=value,
            direction="positive" if value > 0 else "negative",
            impact_level=cls._get_impact_level(abs(value)),
            contribution_to_prediction=value
        )

    def explain_instance(
        self,
        X_instance: pd.DataFrame,
        instance_index: int = 0
    ) -> "SHAPExplanation":
        """
        Generate SHAP explanation for a single instance

        Args:
            X_instance: DataFrame containing the instance, or a single row of
                values ordered like ``feature_names``
            instance_index: Positional index of the row to explain when
                ``X_instance`` is a DataFrame

        Returns:
            SHAPExplanation object. The positive/negative feature lists are
            taken from the five features with the largest absolute SHAP value,
            so together they hold at most five entries.
        """
        if isinstance(X_instance, pd.DataFrame):
            X_instance = X_instance.iloc[instance_index:instance_index + 1]
        else:
            X_instance = pd.DataFrame([X_instance], columns=self.feature_names)

        # Get prediction
        predicted_value = self.model.predict(X_instance)[0]

        # SHAP values of the single explained row, reduced to the first output
        instance_values = self._first_output(self.explainer.shap_values(X_instance))[0]
        # Plain floats keep the explanation JSON-serializable.
        shap_dict = {
            name: float(value)
            for name, value in zip(self.feature_names, instance_values)
        }
        base_value = self._base_value()

        # Five strongest features by magnitude, then split by sign
        strongest = sorted(
            shap_dict.items(),
            key=lambda item: abs(item[1]),
            reverse=True
        )[:5]
        top_positive = [self._as_importance(name, value) for name, value in strongest if value > 0]
        top_negative = [self._as_importance(name, value) for name, value in strongest if value < 0]

        explanation_text = self._generate_explanation_text(
            predicted_value, base_value, top_positive, top_negative
        )
        return SHAPExplanation(
            # NOTE(review): draws from NumPy's global RNG, so ids can collide
            # and generating one advances the caller's random stream.
            prediction_id=f"shap_{np.random.randint(100000, 999999)}",
            predicted_value=predicted_value,
            base_value=base_value,
            shap_values=shap_dict,
            top_positive_features=top_positive,
            top_negative_features=top_negative,
            explanation_text=explanation_text,
            confidence_in_explanation=0.95
        )

    def explain_batch(
        self,
        X_batch: pd.DataFrame,
        num_samples: Optional[int] = None
    ) -> List["SHAPExplanation"]:
        """
        Explain a batch of instances

        Args:
            X_batch: Rows to explain
            num_samples: Maximum number of leading rows to explain; ``None``
                explains every row and ``0`` explains none

        Returns:
            Explanations for the rows that succeeded. Rows that raise are
            logged and skipped, so the result may be shorter than requested.
        """
        explanations = []
        # `num_samples or len(...)` would turn an explicit 0 into "all rows".
        num_to_explain = len(X_batch) if num_samples is None else num_samples
        for i in range(min(num_to_explain, len(X_batch))):
            try:
                explanations.append(self.explain_instance(X_batch, i))
            except Exception as e:
                logger.error(f"Failed to explain instance {i}: {e}")
        return explanations

    def feature_importance_summary(self) -> Dict[str, float]:
        """
        Get global feature importance (mean absolute SHAP values)

        SHAP values for the background data are computed on first use and
        cached on the instance in their raw form.

        Returns:
            Mapping of feature name to mean absolute SHAP value (first output
            only for multi-output models).
        """
        if self.shap_values is None:
            self.shap_values = self.explainer.shap_values(self.X_background)
        shap_array = self._first_output(self.shap_values)
        mean_abs_shap = np.abs(shap_array).mean(axis=0)
        return {name: float(value) for name, value in zip(self.feature_names, mean_abs_shap)}

    @staticmethod
    def _get_impact_level(value: float, thresholds: Tuple[float, float] = (0.05, 0.15)) -> str:
        """
        Classify impact level based on SHAP value magnitude

        Args:
            value: SHAP value; only its magnitude is used
            thresholds: (medium, high) cut-offs, both exclusive
        """
        magnitude = abs(value)
        if magnitude > thresholds[1]:
            return "high"
        if magnitude > thresholds[0]:
            return "medium"
        return "low"

    @staticmethod
    def _generate_explanation_text(
        predicted_value: float,
        base_value: float,
        positive_features: List["FeatureImportance"],
        negative_features: List["FeatureImportance"]
    ) -> str:
        """
        Generate human-readable explanation

        Lists the base value, the prediction, their signed difference and up to
        three contributors from each of the positive and negative lists.
        """
        lines = []
        lines.append(f"Base prediction value: {base_value:.4f}")
        lines.append(f"Predicted value: {predicted_value:.4f}")
        # The "+" format flag emits the correct sign; a hard-coded "+" prefix
        # rendered negative changes as "+-0.1234".
        lines.append(f"Change: {predicted_value - base_value:+.4f}")
        if positive_features:
            lines.append("\nTop positive contributors:")
            for feat in positive_features[:3]:
                lines.append(f" • {feat.feature_name}: +{feat.contribution_to_prediction:.4f}")
        if negative_features:
            lines.append("\nTop negative contributors:")
            for feat in negative_features[:3]:
                lines.append(f" • {feat.feature_name}: {feat.contribution_to_prediction:.4f}")
        return "\n".join(lines)
class LIMEExplainer:
    """LIME-based model explanation interface.

    Builds a ``LimeTabularExplainer`` over the training data and converts its
    output into ``LIMEExplanation`` records. In classification mode the
    explanation targets the class the model actually predicted.
    """

    def __init__(
        self,
        model: Any,
        X_training: pd.DataFrame,
        feature_names: List[str],
        class_names: Optional[List[str]] = None,
        mode: str = "classification"
    ):
        """
        Initialize LIME explainer

        Args:
            model: Trained model (should have predict or predict_proba method)
            X_training: Training data (for LIME reference)
            feature_names: List of feature names
            class_names: List of class names (for classification)
            mode: "classification" or "regression"; any other value is treated
                as regression

        Raises:
            ImportError: If the ``lime`` package is not installed.
        """
        if not LIME_AVAILABLE:
            raise ImportError("LIME library not installed. Install with: pip install lime")
        self.model = model
        self.X_training = X_training
        self.feature_names = feature_names
        self.class_names = class_names or ["Negative", "Positive"]
        self.mode = mode

        explainer_kwargs = {"feature_names": feature_names, "random_state": 42}
        if mode == "classification":
            explainer_kwargs["class_names"] = self.class_names
            explainer_kwargs["mode"] = "classification"
        else:
            explainer_kwargs["mode"] = "regression"
        self.explainer = lime.lime_tabular.LimeTabularExplainer(
            X_training.values,
            **explainer_kwargs
        )

    @staticmethod
    def _class_index(classes: Any, predicted_class: Any, predicted_proba: Any) -> int:
        """
        Find the ``predict_proba`` column that belongs to the predicted label

        Args:
            classes: The model's ``classes_`` (label per probability column),
                or ``None`` when the model does not expose it
            predicted_class: Label returned by ``model.predict``
            predicted_proba: Probability row returned by ``model.predict_proba``

        Returns:
            Position of ``predicted_class`` in ``classes``; falls back to the
            most probable column when the label cannot be located.
        """
        if classes is not None:
            matches = np.flatnonzero(np.asarray(classes) == predicted_class)
            if matches.size:
                return int(matches[0])
        return int(np.argmax(predicted_proba))

    def explain_instance(
        self,
        X_instance: pd.DataFrame,
        instance_index: int = 0,
        num_features: int = 10
    ) -> "LIMEExplanation":
        """
        Generate LIME explanation for a single instance

        Args:
            X_instance: DataFrame containing the instance, or a single row of
                feature values
            instance_index: Positional index of the row when ``X_instance`` is
                a DataFrame
            num_features: Number of features to include in explanation

        Returns:
            LIMEExplanation object
        """
        if isinstance(X_instance, pd.DataFrame):
            X_array = X_instance.iloc[instance_index].values
        else:
            X_array = X_instance
        X_array = np.asarray(X_array)

        if self.mode == "classification":
            predicted_proba = np.asarray(self.model.predict_proba([X_array])[0])
            predicted_class = self.model.predict([X_array])[0]
            # The predicted label is not necessarily a valid column index
            # (string labels, labels not starting at 0), so resolve the column.
            label = self._class_index(
                getattr(self.model, "classes_", None), predicted_class, predicted_proba
            )
            predicted_value = float(predicted_proba[label])
            # LIME only explains label 1 unless asked otherwise; request the
            # predicted class so weights and intercept exist for it.
            exp = self.explainer.explain_instance(
                X_array,
                self.model.predict_proba,
                num_features=num_features,
                labels=(label,)
            )
        else:
            predicted_value = float(self.model.predict([X_array])[0])
            predicted_class = None
            # LIME files regression results under label 1 (the default that
            # as_list() reads) — TODO confirm against the installed version.
            label = 1
            exp = self.explainer.explain_instance(
                X_array,
                self.model.predict,
                num_features=num_features
            )

        # as_map() yields (feature index, weight) pairs, so names come straight
        # from feature_names instead of parsing descriptions such as
        # "0.5 < age <= 1.0".
        contributions = []
        for feature_idx, weight in exp.as_map()[label]:
            weight = float(weight)
            contributions.append(FeatureImportance(
                feature_name=self.feature_names[int(feature_idx)],
                importance_value=weight,
                direction="positive" if weight > 0 else "negative",
                impact_level=self._get_impact_level(abs(weight)),
                contribution_to_prediction=weight
            ))

        # The intercept is normally a mapping keyed by label; accept a bare
        # number as well.
        intercept = exp.intercept
        if isinstance(intercept, dict):
            intercept = intercept[label]

        explanation_text = self._generate_explanation_text(
            predicted_value, contributions, predicted_class
        )
        return LIMEExplanation(
            # NOTE(review): draws from NumPy's global RNG, so ids can collide
            # and generating one advances the caller's random stream.
            prediction_id=f"lime_{np.random.randint(100000, 999999)}",
            predicted_value=predicted_value,
            predicted_class=str(predicted_class) if predicted_class is not None else None,
            probability=predicted_value,
            local_feature_importance=contributions,
            explanation_text=explanation_text,
            intercept=float(intercept)
        )

    @staticmethod
    def _get_impact_level(value: float, thresholds: Tuple[float, float] = (0.05, 0.15)) -> str:
        """
        Classify impact level

        Args:
            value: LIME weight; only its magnitude is used
            thresholds: (medium, high) cut-offs, both exclusive
        """
        magnitude = abs(value)
        if magnitude > thresholds[1]:
            return "high"
        if magnitude > thresholds[0]:
            return "medium"
        return "low"

    @staticmethod
    def _generate_explanation_text(
        predicted_value: float,
        contributions: List["FeatureImportance"],
        predicted_class: Optional[Any] = None
    ) -> str:
        """
        Generate human-readable explanation

        Args:
            predicted_value: Score shown on the "Prediction score" line
            contributions: Feature contributions; the first five are listed
            predicted_class: Predicted label, or ``None`` to omit the class line
        """
        lines = []
        # Compare against None: a plain truthiness test hides class label 0.
        if predicted_class is not None:
            lines.append(f"Predicted class: {predicted_class}")
        lines.append(f"Prediction score: {predicted_value:.4f}")
        lines.append("\nTop contributing features:")
        for contrib in contributions[:5]:
            sign = "+" if contrib.importance_value > 0 else ""
            lines.append(f" • {contrib.feature_name}: {sign}{contrib.importance_value:.4f}")
        return "\n".join(lines)
class PermutationImportanceExplainer:
    """Feature importance obtained by permuting features on held-out data."""

    def __init__(self, model: Any, X_test: pd.DataFrame, y_test: pd.Series, feature_names: List[str]):
        """Keep references to the model, evaluation data and feature names.

        Nothing is computed here; ``importance`` stays ``None`` until
        ``compute_importance()`` has been called.
        """
        self.model, self.X_test, self.y_test = model, X_test, y_test
        self.feature_names = feature_names
        self.importance = None

    def compute_importance(
        self,
        n_repeats: int = 10,
        random_state: int = 42
    ) -> Dict[str, float]:
        """Run scikit-learn's permutation importance and cache the mean scores.

        Args:
            n_repeats: Number of permutations per feature
            random_state: Seed passed through to scikit-learn

        Returns:
            Mapping of feature name to mean importance (also stored on
            ``self.importance``).
        """
        logger.info(f"Computing permutation importance ({n_repeats} repeats)...")
        mean_scores = permutation_importance(
            self.model,
            self.X_test,
            self.y_test,
            n_repeats=n_repeats,
            random_state=random_state,
            n_jobs=-1
        ).importances_mean
        self.importance = {
            name: score for name, score in zip(self.feature_names, mean_scores)
        }
        logger.info("Permutation importance computed successfully")
        return self.importance

    def get_top_features(self, n: int = 10) -> List[Tuple[str, float]]:
        """Return the ``n`` highest-scoring ``(feature, importance)`` pairs.

        Raises:
            ValueError: If ``compute_importance()`` has not been called yet.
        """
        if self.importance is None:
            raise ValueError("Call compute_importance() first")
        ranked = list(self.importance.items())
        ranked.sort(key=lambda pair: pair[1], reverse=True)
        return ranked[:n]
class ExplanationCache:
    """Cache for explanations to avoid redundant computation.

    Entries live in an in-memory dict and are mirrored to ``<cache_dir>/<key>.json``.
    Values written to disk go through ``json.dump(..., default=str)``, so objects
    that are not JSON-serializable come back as strings after a reload from disk.
    """

    def __init__(self, cache_dir: Optional[Path] = None):
        """
        Create the cache and make sure its directory exists

        Args:
            cache_dir: Directory for the JSON files (``Path`` or ``str``).
                Defaults to the module-level ``EXPLANATIONS_DIR``.
        """
        # Resolve the default at call time and accept plain strings.
        self.cache_dir = Path(EXPLANATIONS_DIR if cache_dir is None else cache_dir)
        # parents=True lets callers point the cache at a nested directory.
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.memory_cache: Dict[str, Dict] = {}

    def _cache_file(self, key: str) -> Path:
        """Return the on-disk location for ``key``."""
        return self.cache_dir / f"{key}.json"

    def get(self, key: str) -> Optional[Dict]:
        """
        Get explanation from cache

        Looks in memory first, then on disk; a disk hit is promoted to memory.

        Returns:
            The cached value, or ``None`` when the key is unknown or its file
            cannot be read or parsed.
        """
        if key in self.memory_cache:
            return self.memory_cache[key]

        cache_file = self._cache_file(key)
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            return None
        except (OSError, ValueError) as exc:
            # A corrupt or unreadable entry is a cache miss, not a failure.
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors.
            logger.warning(f"Ignoring unreadable cache entry {cache_file}: {exc}")
            return None
        self.memory_cache[key] = data
        return data

    def set(self, key: str, value: Dict):
        """
        Store explanation in cache

        Args:
            key: Cache key; also used as the JSON file's base name
            value: Explanation data to store
        """
        self.memory_cache[key] = value
        with open(self._cache_file(key), "w", encoding="utf-8") as f:
            json.dump(value, f, default=str)

    def clear(self, include_disk: bool = False):
        """
        Clear cache

        Args:
            include_disk: When ``False`` (default) only the in-memory entries
                are dropped and a later ``get`` can still reload them from
                disk. When ``True`` every ``*.json`` file in ``cache_dir`` is
                deleted as well.
        """
        self.memory_cache.clear()
        if not include_disk:
            return
        for cache_file in self.cache_dir.glob("*.json"):
            try:
                cache_file.unlink()
            except FileNotFoundError:
                # Already removed by someone else; nothing left to do.
                pass
if __name__ == "__main__":
    # Running the file directly only reports which optional backends could be
    # imported; the explainers themselves are used by other components.
    logging.basicConfig(level=logging.INFO)
    banner = "="*80
    for message in (
        banner,
        "COGNEXA Explainable AI (XAI) Module",
        banner,
        "XAI module loaded successfully",
        f"SHAP available: {SHAP_AVAILABLE}",
        f"LIME available: {LIME_AVAILABLE}",
    ):
        logger.info(message)