# gapura-ai-api / app / shap_service.py
# Author: Muhammad Ridzki Nugraha
# Uploaded via huggingface_hub (commit 13c3f2c, verified)
"""
SHAP Explainability Service for Gapura AI
Provides feature importance explanations for individual predictions
"""
import os
import logging
import pickle
from typing import List, Dict, Any, Optional
import numpy as np
logger = logging.getLogger(__name__)
class ShapExplainer:
    """SHAP-based explainer for model predictions.

    Wraps a persisted ``shap.TreeExplainer`` so API requests can be
    annotated with per-feature contributions. If no explainer has been
    trained/saved yet (or SHAP fails at runtime), every request degrades
    gracefully to a fixed fallback payload instead of raising.
    """

    def __init__(self):
        # Populated by _load_explainer() when a pickled explainer exists
        # on disk; otherwise left as None / empty and the service answers
        # with _fallback_explanation().
        self.explainer = None
        self.feature_names: List[str] = []
        self.background_data = None
        self._load_explainer()

    def _explainer_path(self) -> str:
        """Return the canonical on-disk location of the pickled explainer.

        Single source of truth for the path previously duplicated by the
        load and create code paths.
        """
        base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(base_dir, "models", "regression", "shap_explainer.pkl")

    def _load_explainer(self):
        """Load a previously saved SHAP explainer from disk, if present.

        Failures are logged and swallowed on purpose: a missing or corrupt
        explainer must not prevent the service from starting.
        """
        explainer_path = self._explainer_path()
        if not os.path.exists(explainer_path):
            return
        try:
            # SECURITY: pickle.load executes arbitrary code on malicious
            # input. This is acceptable only because the file is written
            # exclusively by create_explainer() below; never point this
            # loader at data from an untrusted source.
            with open(explainer_path, "rb") as f:
                explainer_data = pickle.load(f)
            self.explainer = explainer_data.get("explainer")
            self.feature_names = explainer_data.get("feature_names", [])
            logger.info("SHAP explainer loaded successfully")
        except Exception as e:
            logger.warning("Failed to load SHAP explainer: %s", e)
            self.explainer = None

    def create_explainer(
        self, model, X_background: np.ndarray, feature_names: List[str]
    ):
        """Create a ``shap.TreeExplainer`` for *model* and persist it.

        Args:
            model: A tree-based model accepted by ``shap.TreeExplainer``.
            X_background: 2-D background dataset used for the expected-value
                baseline; subsampled to at most 100 rows to keep the pickle
                small and explanation latency low.
            feature_names: Column names aligned with ``X_background``.

        Returns:
            True on success, False if SHAP is unavailable or saving failed.
        """
        try:
            import shap

            if len(X_background) > 100:
                # Random subsample without replacement; 100 rows is plenty
                # for a stable expected-value baseline.
                indices = np.random.choice(len(X_background), 100, replace=False)
                X_background = X_background[indices]
            self.explainer = shap.TreeExplainer(model, X_background)
            self.feature_names = feature_names
            explainer_path = self._explainer_path()
            os.makedirs(os.path.dirname(explainer_path), exist_ok=True)
            with open(explainer_path, "wb") as f:
                pickle.dump(
                    {
                        "explainer": self.explainer,
                        "feature_names": feature_names,
                    },
                    f,
                )
            logger.info("SHAP explainer created and saved to %s", explainer_path)
            return True
        except Exception as e:
            logger.error("Failed to create SHAP explainer: %s", e)
            return False

    def explain_prediction(self, X: np.ndarray, top_n: int = 10) -> Dict[str, Any]:
        """Generate a SHAP explanation for a single prediction.

        Args:
            X: Feature row(s) for one prediction, shaped as the model expects.
            top_n: How many highest-|SHAP| features to surface in
                ``top_factors`` / ``positive_factors`` / ``negative_factors``.

        Returns:
            Dict with the SHAP base value, ranked factor lists, a
            human-readable ``explanation`` string, and the full
            feature -> shap_value mapping. Falls back to a stub payload
            (``prediction_explained`` False) when no explainer is loaded
            or SHAP raises.
        """
        if self.explainer is None:
            return self._fallback_explanation(X, top_n)
        try:
            shap_values = self.explainer.shap_values(X)
            # Multi-output models return a list of arrays; this service
            # explains a single regression output, so take the first.
            if isinstance(shap_values, list):
                shap_values = shap_values[0]
            shap_values = np.array(shap_values).flatten()
            base_value = self.explainer.expected_value
            if isinstance(base_value, np.ndarray):
                base_value = base_value[0]
            # zip() also guards against a length mismatch between the
            # stored feature names and the returned SHAP vector.
            feature_contributions = [
                {
                    "feature": name,
                    "shap_value": float(shap_val),
                    "abs_contribution": abs(float(shap_val)),
                    "direction": "increases" if shap_val > 0 else "decreases",
                }
                for name, shap_val in zip(self.feature_names, shap_values)
            ]
            feature_contributions.sort(
                key=lambda c: c["abs_contribution"], reverse=True
            )
            top_features = feature_contributions[:top_n]
            positive_factors = [f for f in top_features if f["shap_value"] > 0]
            negative_factors = [f for f in top_features if f["shap_value"] < 0]
            explanation_text = self._generate_explanation_text(top_features, base_value)
            return {
                "base_value": float(base_value),
                "prediction_explained": True,
                "top_factors": top_features,
                "positive_factors": positive_factors,
                "negative_factors": negative_factors,
                "explanation": explanation_text,
                "shap_values": {
                    f["feature"]: f["shap_value"] for f in feature_contributions
                },
            }
        except Exception as e:
            logger.error("SHAP explanation failed: %s", e)
            return self._fallback_explanation(X, top_n)

    def _fallback_explanation(self, X: np.ndarray, top_n: int) -> Dict[str, Any]:
        """Stub explanation used when SHAP is unavailable.

        *X* and *top_n* are accepted for signature parity with
        explain_prediction but are intentionally unused. The base_value of
        2.0 is a placeholder baseline — presumably a typical resolution
        time in the model's units; TODO confirm against training data.
        """
        return {
            "base_value": 2.0,
            "prediction_explained": False,
            "top_factors": [],
            "positive_factors": [],
            "negative_factors": [],
            "explanation": "SHAP explainer not available. Train model to enable explainability.",
            "shap_values": {},
        }

    def _generate_explanation_text(
        self, top_features: List[Dict], base_value: float
    ) -> str:
        """Build a human-readable sentence from the top (max 3) factors.

        Feature names are prettified by dropping underscores and the
        " encoded" suffix produced by label encoding. *base_value* is kept
        for signature stability but not used in the text.
        """
        if not top_features:
            return "No significant factors identified."
        explanations = []
        for factor in top_features[:3]:
            direction = factor["direction"]
            feature = factor["feature"].replace("_", " ").replace(" encoded", "")
            if direction == "increases":
                explanations.append(f"Higher {feature} increases resolution time")
            else:
                explanations.append(f"Higher {feature} decreases resolution time")
        return ". ".join(explanations) + "."
# Lazily-created, process-wide explainer instance (see get_shap_explainer).
_shap_explainer: "Optional[ShapExplainer]" = None
def get_shap_explainer() -> "ShapExplainer":
    """Return the singleton ShapExplainer, constructing it on first use."""
    global _shap_explainer
    if _shap_explainer is not None:
        return _shap_explainer
    _shap_explainer = ShapExplainer()
    return _shap_explainer