ddi / src /validation /explainability.py
github-actions[bot]
Deploy from GitHub Actions (fb28c05c54cf19184fc3f14f1bf3297ba5749ea2)
d29b763
"""Explainability helpers for DDI models.
Supports SHAP explanations for sklearn-compatible models and rationale summaries.
"""
from __future__ import annotations
import json
from pathlib import Path
from typing import Any, Dict, List
import joblib
import numpy as np
try:
import shap
except Exception: # pragma: no cover
shap = None # type: ignore
def summarize_top_features(values: np.ndarray, feature_names: List[str], top_k: int = 8) -> List[Dict[str, Any]]:
    """Return the ``top_k`` features with the largest absolute contribution.

    Args:
        values: One-dimensional per-feature contributions (e.g. the SHAP vector
            for a single sample). Anything ``np.asarray`` accepts is allowed.
        feature_names: Names aligned with ``values`` by position. Positions past
            the end of this list are reported as ``'f_<index>'``.
        top_k: Maximum number of entries to return. Values ``<= 0`` yield an
            empty list; ``None`` returns every feature.

    Returns:
        A list of dicts with keys ``'feature'``, ``'contribution'`` (signed) and
        ``'abs_contribution'``, ordered by descending absolute contribution.
        Ties keep their positional order; NaN contributions are ranked last.

    Raises:
        ValueError: If ``values`` is not one-dimensional.
    """
    contributions = np.asarray(values, dtype=float)
    if contributions.ndim != 1:
        raise ValueError(f'values must be 1-D, got shape {contributions.shape}')
    # A negative top_k would otherwise slice from the end ([:-k]) and return
    # nearly every feature instead of none. None keeps the old "all" slice.
    k = None if top_k is None else max(int(top_k), 0)
    # Stable sort on the negated magnitude: descending, deterministic for ties,
    # and NaN (which numpy sorts last) ends up at the bottom rather than being
    # reported as the strongest feature.
    order = np.argsort(-np.abs(contributions), kind='stable')[:k]
    return [
        {
            'feature': feature_names[i] if i < len(feature_names) else f'f_{i}',
            'contribution': float(contributions[i]),
            'abs_contribution': float(abs(contributions[i])),
        }
        for i in order
    ]
def confidence_reasoning(confidence: float, severe_probability: float, top2_margin: float) -> Dict[str, Any]:
    """Map a prediction's confidence to a band and a handling recommendation.

    Bands: ``confidence < 0.55`` -> ``'low'``, ``< 0.75`` -> ``'medium'``,
    otherwise ``'high'``. A non-finite confidence (NaN or +/-inf) is always
    ``'low'`` so a broken score escalates to expert review instead of being
    trusted. When ``severe_probability >= 0.45`` and ``top2_margin <= 0.15``
    an extra sentence recommending conservative handling is appended.

    Args:
        confidence: Score used to pick the band.
        severe_probability: Probability assigned to the severe class.
        top2_margin: Gap between the two highest class scores.

    Returns:
        Dict with ``'confidence_band'``, ``'rationale'`` and the two inputs
        ``'severe_probability'`` / ``'top2_margin'`` converted to ``float``.
    """
    # NaN compares False against every threshold, so without the finiteness
    # check it would fall through to the 'high' band.
    if not np.isfinite(confidence) or confidence < 0.55:
        band = 'low'
        rationale = 'Prediction confidence is low; escalate to expert review before clinical action.'
    elif confidence < 0.75:
        band = 'medium'
        rationale = 'Prediction confidence is moderate; use as advisory signal with monitoring.'
    else:
        band = 'high'
        rationale = 'Prediction confidence is high; still validate against patient-specific contraindications.'
    # A likely-severe outcome with a narrow lead over the runner-up class
    # warrants extra caution regardless of the band.
    if severe_probability >= 0.45 and top2_margin <= 0.15:
        rationale += ' Severe interaction probability is elevated with narrow class margin, suggesting conservative handling.'
    return {
        'confidence_band': band,
        'rationale': rationale,
        'severe_probability': float(severe_probability),
        'top2_margin': float(top2_margin),
    }
def interaction_rationale_summary(
    predicted_label: str,
    confidence: float,
    top_features: List[Dict[str, Any]],
    severe_probability: float,
    top2_margin: float,
) -> str:
    """Build a short plain-language explanation for one interaction prediction.

    The text has three sentences:
    - the predicted label with its confidence (3 decimals);
    - the ``'feature'`` names of the first three entries of ``top_features``,
      in the order given, or ``'no dominant features'`` when the list is empty;
    - the rationale produced by ``confidence_reasoning``.
    """
    if top_features:
        strongest = ', '.join(str(item['feature']) for item in top_features[:3])
    else:
        strongest = 'no dominant features'
    reasoning = confidence_reasoning(confidence, severe_probability, top2_margin)
    sentences = [
        f'Predicted {predicted_label} interaction with confidence {confidence:.3f}.',
        f'Key contributing feature groups: {strongest}.',
        reasoning['rationale'],
    ]
    return ' '.join(sentences)
def explain_with_shap(model_path: Path, X_sample: np.ndarray, feature_names: List[str], out_path: Path) -> Dict[str, Any]:
    """Explain a saved classifier with SHAP and write a per-sample JSON report.

    The model is loaded from ``model_path`` with ``joblib`` and explained on
    ``X_sample``, which serves both as the SHAP background data and as the rows
    being explained. Each row's predicted class is the argmax of
    ``model.predict_proba``. When SHAP returns one contribution vector per
    class (3-D values), the vector for that class is summarised; otherwise the
    row's single vector is used.

    Args:
        model_path: joblib file holding a model that exposes ``predict_proba``.
        X_sample: Rows to explain; ``X_sample.shape[0]`` is the sample count.
        feature_names: Column names forwarded to ``summarize_top_features``.
        out_path: Destination of the JSON report; parent directories are created.

    Returns:
        The payload written to ``out_path``: ``{'num_samples', 'reports'}``,
        where each report holds ``sample_index``, ``predicted_class``,
        ``confidence`` (max class probability) and ``top_features``.

    Raises:
        RuntimeError: If the optional ``shap`` dependency is not installed.
    """
    if shap is None:
        raise RuntimeError('shap is not installed; install shap to enable explainability reports')

    # NOTE(review): joblib.load unpickles arbitrary objects; only pass trusted model files.
    model = joblib.load(model_path)
    explanation = shap.Explainer(model, X_sample)(X_sample)
    contributions = explanation.values
    probabilities = model.predict_proba(X_sample)

    def _row_report(row: int) -> Dict[str, Any]:
        # Build the report entry for a single explained row.
        row_probs = probabilities[row]
        winner = int(np.argmax(row_probs))
        if contributions.ndim == 3:
            # (sample, feature, class): keep the predicted class's attributions.
            row_contrib = contributions[row, :, winner]
        else:
            # NOTE(review): single-output attributions are used as-is; their sign
            # may refer to one fixed model output rather than the predicted class.
            row_contrib = contributions[row]
        return {
            'sample_index': row,
            'predicted_class': winner,
            'confidence': float(np.max(row_probs)),
            'top_features': summarize_top_features(np.array(row_contrib), feature_names),
        }

    n_rows = int(X_sample.shape[0])
    payload = {
        'num_samples': n_rows,
        'reports': [_row_report(row) for row in range(n_rows)],
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, indent=2), encoding='utf-8')
    return payload