# Source page header (Hugging Face): uploaded by nananie143 as
# src/explainability/shap_explainer.py via huggingface_hub (commit 8c80331, verified).
"""
SHAP and LIME Explainability Module
Provides interpretable AI explanations for predictions.
Based on the blueprint for model explainability.
"""
import numpy as np
import pandas as pd
from typing import Dict, List, Tuple, Optional, Any
import logging
from pathlib import Path
logger = logging.getLogger(__name__)

# Optional explainability back-ends. Each HAS_* flag records whether the
# library imported cleanly, so the explainer classes can degrade to their
# fallback explanations instead of failing at import time.
try:
    import shap
except ImportError:
    HAS_SHAP = False
    logger.warning("SHAP not installed. Install with: pip install shap")
else:
    HAS_SHAP = True

try:
    import lime
    import lime.lime_tabular
except ImportError:
    HAS_LIME = False
    logger.warning("LIME not installed. Install with: pip install lime")
else:
    HAS_LIME = True
class SHAPExplainer:
    """
    SHAP-based model explainability.
    Provides feature importance and prediction explanations.

    Call ``fit`` before explaining. Until a SHAP explainer has been created,
    ``explain_prediction`` returns a magnitude-based fallback explanation.
    """

    def __init__(self, model: Any, feature_names: Optional[List[str]] = None):
        self.model = model
        self.feature_names = feature_names
        self.explainer = None
        self.background_data = None

    def fit(
        self,
        background_data: np.ndarray,
        sample_size: int = 100,
        random_state: Optional[int] = None,
    ):
        """Initialize explainer with background data.

        Args:
            background_data: Rows used as the KernelExplainer background.
                This is a NumPy array or any object with ``.iloc``, such as
                a pandas DataFrame.
            sample_size: Maximum number of background rows kept. Larger
                inputs are subsampled without replacement.
            random_state: Optional seed for that subsampling. ``None`` draws
                from NumPy's global RNG, as before.
        """
        if not HAS_SHAP:
            logger.warning("SHAP not available")
            return
        # Sample background data
        if len(background_data) > sample_size:
            rng = np.random if random_state is None else np.random.RandomState(random_state)
            indices = rng.choice(len(background_data), sample_size, replace=False)
            # Select rows by position; plain ``[]`` on a DataFrame selects columns.
            if hasattr(background_data, 'iloc'):
                self.background_data = background_data.iloc[indices]
            else:
                self.background_data = background_data[indices]
        else:
            self.background_data = background_data
        # Forget any previous explainer so a failed re-fit cannot leave a stale one.
        self.explainer = None
        try:
            # Try TreeExplainer for tree-based models
            self.explainer = shap.TreeExplainer(self.model)
            logger.info("Using TreeExplainer")
        except Exception:
            try:
                # Fall back to the model-agnostic KernelExplainer
                predict_fn = (
                    self.model.predict_proba
                    if hasattr(self.model, 'predict_proba')
                    else self.model.predict
                )
                self.explainer = shap.KernelExplainer(predict_fn, self.background_data)
                logger.info("Using KernelExplainer")
            except Exception as e:
                logger.error(f"Failed to create SHAP explainer: {e}")

    def explain_prediction(
        self,
        features: np.ndarray,
        class_index: Optional[int] = None
    ) -> Dict:
        """
        Explain a single prediction.

        Args:
            features: One feature row, either 1-D or shape ``(1, n)``.
            class_index: Output/class to explain for multi-output models.
                Defaults to index 1 when there are several outputs.

        Returns:
            Dictionary with ``top_features`` (up to 10, ranked by absolute
            SHAP value), ``base_value`` (the expected value of the same
            output) and ``total_contribution``.
        """
        if self.explainer is None:
            return self._fallback_explanation(features)
        try:
            row = np.asarray(features).reshape(1, -1)
            shap_values = self.explainer.shap_values(row)
            per_output, output_index = self._select_output(shap_values, class_index)
            values = np.atleast_2d(np.asarray(per_output, dtype=float))[0]
            return {
                'top_features': self._rank_features(np.abs(values), values),
                'base_value': self._base_value(output_index),
                'total_contribution': float(np.sum(values)),
            }
        except Exception as e:
            logger.error(f"SHAP explanation failed: {e}")
            return self._fallback_explanation(features)

    def get_global_importance(self, X: np.ndarray) -> pd.DataFrame:
        """Get global feature importance (mean absolute SHAP value) across ``X``.

        Returns an empty DataFrame when no explainer is fitted or SHAP fails.
        """
        if self.explainer is None:
            return pd.DataFrame()
        try:
            shap_values = self.explainer.shap_values(X)
            per_output, _ = self._select_output(shap_values)
            importance = np.abs(np.asarray(per_output, dtype=float)).mean(axis=0)
            df = pd.DataFrame({
                # Pad or truncate names so a length mismatch cannot break the frame.
                'feature': [self._feature_name(i) for i in range(len(importance))],
                'importance': importance,
            })
            return df.sort_values('importance', ascending=False)
        except Exception as e:
            logger.error(f"Global importance failed: {e}")
            return pd.DataFrame()

    @staticmethod
    def _select_output(shap_values: Any, class_index: Optional[int] = None):
        """Reduce SHAP output to a ``(n_samples, n_features)`` array for one output.

        Older SHAP returns a list with one array per class. Newer SHAP returns
        a single array with the class axis last. When there are several outputs
        and no ``class_index`` is given, index 1 is used, as in the original
        "positive class" default.

        Returns:
            ``(values, output_index)``. ``output_index`` is ``None`` for
            single-output results.
        """
        if isinstance(shap_values, list):
            idx = class_index if class_index is not None else (1 if len(shap_values) > 1 else 0)
            return shap_values[idx], idx
        arr = np.asarray(shap_values)
        if arr.ndim == 3:
            idx = class_index if class_index is not None else (1 if arr.shape[-1] > 1 else 0)
            return arr[..., idx], idx
        return arr, None

    def _base_value(self, output_index: Optional[int]) -> float:
        """Return the explainer's expected value for the output that was explained."""
        expected = np.ravel(np.asarray(self.explainer.expected_value, dtype=float))
        if expected.size == 1:
            return float(expected[0])
        return float(expected[output_index if output_index is not None else 0])

    def _feature_name(self, idx: int) -> str:
        """Return the configured feature name, or ``Feature_<idx>`` when none is available."""
        if self.feature_names and idx < len(self.feature_names):
            return self.feature_names[idx]
        return f"Feature_{idx}"

    def _rank_features(self, importance: np.ndarray, contributions: np.ndarray, top_k: int = 10) -> List[Dict]:
        """Build the top-``top_k`` feature records, ordered by descending importance."""
        order = np.argsort(importance)[::-1][:top_k]
        return [
            {
                'feature': self._feature_name(int(idx)),
                'importance': float(importance[idx]),
                'contribution': float(contributions[idx]),
                # Zero contributions are labelled 'negative', as before.
                'direction': 'positive' if contributions[idx] > 0 else 'negative',
            }
            for idx in order
        ]

    def _fallback_explanation(self, features: np.ndarray) -> Dict:
        """Simple fallback when SHAP is unavailable: rank features by raw magnitude."""
        values = np.asarray(features, dtype=float).ravel()
        return {
            'top_features': self._rank_features(np.abs(values), values),
            'base_value': 0.0,
            'total_contribution': 0.0,
            'note': 'Fallback explanation (SHAP unavailable)'
        }
class LIMEExplainer:
    """
    LIME-based model explainability.
    Provides local interpretable model-agnostic explanations.

    Call ``fit`` before explaining. Until then, ``explain_prediction`` returns
    a magnitude-based fallback explanation.
    """

    def __init__(
        self,
        model: Any,
        feature_names: Optional[List[str]] = None,
        class_names: Optional[List[str]] = None
    ):
        self.model = model
        self.feature_names = feature_names
        self.class_names = class_names or ['Away', 'Draw', 'Home']
        self.explainer = None

    def fit(self, training_data: np.ndarray):
        """Initialize LIME explainer with training data.

        DataFrames are converted to a plain array, because
        ``LimeTabularExplainer`` expects a 2-D NumPy array.
        """
        if not HAS_LIME:
            logger.warning("LIME not available")
            return
        self.explainer = lime.lime_tabular.LimeTabularExplainer(
            np.asarray(training_data),
            feature_names=self.feature_names,
            class_names=self.class_names,
            mode='classification'
        )
        logger.info("LIME explainer initialized")

    def explain_prediction(
        self,
        features: np.ndarray,
        num_features: int = 10,
        label: int = 1
    ) -> Dict:
        """
        Explain a single prediction using LIME.

        Args:
            features: One feature row, either 1-D or shape ``(1, n)``.
            num_features: Maximum number of features in the explanation.
            label: Class index to explain. The default of 1 matches LIME's
                own default ``labels=(1,)``.

        Returns:
            Dictionary with ``top_features``, ``local_prediction`` and
            ``score``. The last two are floats, or ``None`` when LIME does
            not provide them.
        """
        if self.explainer is None:
            return self._fallback_explanation(features)
        try:
            if hasattr(self.model, 'predict_proba'):
                predict_fn = self.model.predict_proba
            else:
                # NOTE(review): LIME classification mode needs class probabilities;
                # a label-only predict() will likely be rejected and hit the fallback.
                predict_fn = self.model.predict
            explanation = self.explainer.explain_instance(
                np.asarray(features).ravel(),
                predict_fn,
                labels=(label,),
                num_features=num_features,
            )
            top_features = [
                {
                    'feature': feature_desc,
                    'importance': abs(float(weight)),
                    'contribution': float(weight),
                    'direction': 'positive' if weight > 0 else 'negative',
                }
                for feature_desc, weight in explanation.as_list(label=label)
            ]
            return {
                'top_features': top_features,
                'local_prediction': self._per_label(getattr(explanation, 'local_pred', None), label),
                'score': self._per_label(getattr(explanation, 'score', None), label),
            }
        except Exception as e:
            logger.error(f"LIME explanation failed: {e}")
            return self._fallback_explanation(features)

    @staticmethod
    def _per_label(value: Any, label: int) -> Optional[float]:
        """Extract a float from a LIME attribute that may be a per-label dict or a scalar/array."""
        # Recent LIME versions store score/local_pred in dicts keyed by label;
        # older ones store a single value or an array.
        if isinstance(value, dict):
            value = value.get(label)
        if value is None:
            return None
        return float(np.ravel(value)[0])

    def _feature_name(self, idx: int) -> str:
        """Return the configured feature name, or ``Feature_<idx>`` when none is available."""
        if self.feature_names and idx < len(self.feature_names):
            return self.feature_names[idx]
        return f"Feature_{idx}"

    def _fallback_explanation(self, features: np.ndarray) -> Dict:
        """Simple fallback when LIME is unavailable: rank up to 10 features by raw magnitude."""
        values = np.asarray(features, dtype=float).ravel()
        importance = np.abs(values)
        top_features = [
            {
                'feature': self._feature_name(int(idx)),
                'importance': float(importance[idx]),
                'contribution': float(values[idx]),
                'direction': 'positive' if values[idx] > 0 else 'negative',
            }
            for idx in np.argsort(importance)[::-1][:10]
        ]
        return {
            'top_features': top_features,
            'note': 'Fallback explanation (LIME unavailable)'
        }
class PredictionExplainer:
    """
    Combined explainability system using SHAP and LIME.
    Provides comprehensive prediction explanations.

    A sub-explainer is created only when it is requested and its library
    imported successfully (``HAS_SHAP`` / ``HAS_LIME``). Call ``fit`` with
    training data before ``explain``.
    """

    def __init__(
        self,
        model: Any,
        feature_names: Optional[List[str]] = None,
        use_shap: bool = True,
        use_lime: bool = True
    ):
        self.model = model
        self.feature_names = feature_names
        self.shap_explainer = None
        self.lime_explainer = None
        if use_shap and HAS_SHAP:
            self.shap_explainer = SHAPExplainer(model, feature_names)
        if use_lime and HAS_LIME:
            self.lime_explainer = LIMEExplainer(model, feature_names)

    def fit(self, training_data: np.ndarray):
        """Initialize every available explainer with the same training data."""
        if self.shap_explainer is not None:
            self.shap_explainer.fit(training_data)
        if self.lime_explainer is not None:
            self.lime_explainer.fit(training_data)

    def explain(
        self,
        features: np.ndarray,
        prediction: Optional[Dict] = None,
        max_feature_values: int = 20
    ) -> Dict:
        """
        Generate comprehensive explanation for a prediction.

        Args:
            features: One feature row (1-D, shape ``(1, n)``, or a sequence).
            prediction: Optional prediction payload echoed back under ``'prediction'``.
            max_feature_values: How many leading named features to report
                under ``'feature_values'``.

        Returns:
            Dictionary with ``prediction``, ``feature_values``, optional
            ``shap`` / ``lime`` sub-explanations, and a text ``summary``.
        """
        row = np.asarray(features).ravel()
        result = {
            'prediction': prediction,
            'feature_values': {}
        }
        # These are the first N named features in column order; they are not ranked.
        if self.feature_names:
            for name, value in zip(self.feature_names[:max_feature_values], row):
                result['feature_values'][name] = float(value)
        if self.shap_explainer is not None:
            result['shap'] = self.shap_explainer.explain_prediction(row)
        if self.lime_explainer is not None:
            result['lime'] = self.lime_explainer.explain_prediction(row)
        result['summary'] = self._generate_summary(result)
        return result

    def _generate_summary(self, explanation: Dict) -> str:
        """Generate a human-readable summary from the top 3 features.

        Uses SHAP when it produced features, otherwise LIME.
        """
        source = None
        for key in ('shap', 'lime'):
            candidate = explanation.get(key)
            if candidate and candidate.get('top_features'):
                source = candidate
                break
        if source is None:
            return "No explanation available"
        top_3 = source['top_features'][:3]
        positive_factors = [f['feature'] for f in top_3 if f['direction'] == 'positive']
        negative_factors = [f['feature'] for f in top_3 if f['direction'] == 'negative']
        summary_parts = []
        if positive_factors:
            summary_parts.append(f"Key positive factors: {', '.join(positive_factors)}")
        if negative_factors:
            summary_parts.append(f"Key negative factors: {', '.join(negative_factors)}")
        return '. '.join(summary_parts) if summary_parts else "No explanation available"
# Global instance: a lazily created explainer shared by the helpers below.
_explainer: Optional["PredictionExplainer"] = None


def get_explainer(
    model: Any = None,
    feature_names: Optional[List[str]] = None
) -> Optional["PredictionExplainer"]:
    """Get or create the shared prediction explainer.

    Args:
        model: Model to explain. If it is given and is not the model the cached
            explainer wraps, a new explainer is created for it, so the caller
            never gets an explainer for a stale model.
        feature_names: Feature names, used only when a new explainer is created.

    Returns:
        The shared explainer, or ``None`` when none exists and no model was given.

    NOTE(review): an explainer created here is not fitted. Unless ``.fit(...)``
    is called on it with training data, its SHAP/LIME parts return their
    fallback explanations.
    """
    global _explainer
    if model is not None and (_explainer is None or _explainer.model is not model):
        _explainer = PredictionExplainer(model, feature_names)
    return _explainer


def explain_prediction(
    features: np.ndarray,
    prediction: Optional[Dict] = None,
    model: Any = None,
    feature_names: Optional[List[str]] = None
) -> Dict:
    """Quick function to explain a prediction with the shared explainer.

    Returns ``{'error': 'No explainer available'}`` when no explainer exists
    and no model was supplied.
    """
    explainer = get_explainer(model, feature_names)
    if explainer is not None:
        return explainer.explain(features, prediction)
    return {'error': 'No explainer available'}