Spaces:
Running
Running
| """Explainability helpers for DDI models. | |
| Supports SHAP explanations for sklearn-compatible models and rationale summaries. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import joblib | |
| import numpy as np | |
| try: | |
| import shap | |
| except Exception: # pragma: no cover | |
| shap = None # type: ignore | |
def summarize_top_features(values: np.ndarray, feature_names: List[str], top_k: int = 8) -> List[Dict[str, Any]]:
    """Rank features by absolute contribution and return the strongest ones.

    Args:
        values: Per-feature contribution values (for example one SHAP
            vector). Any array-like is accepted and flattened to 1-D.
        feature_names: Names aligned by position with ``values``. A position
            that has no name is reported as ``f_<index>``.
        top_k: Maximum number of entries to return. Zero or a negative value
            yields an empty list; ``None`` returns every feature.

    Returns:
        A list of dicts with the keys ``feature``, ``contribution`` (signed)
        and ``abs_contribution``, ordered from largest to smallest magnitude.
        Equal magnitudes keep their positional order and NaN contributions
        are ranked last.
    """
    contributions = np.asarray(values, dtype=float).ravel()
    # A negative slice bound would silently drop entries from the tail
    # instead of limiting the count, so clamp it to zero.
    limit = None if top_k is None else max(int(top_k), 0)
    # Sorting the negated magnitudes with a stable sort gives a descending
    # ranking where ties stay in index order. argsort always places NaN at
    # the end, so NaN values cannot float to the top of the report.
    ranking = np.argsort(-np.abs(contributions), kind='stable')[:limit]
    known_names = len(feature_names)

    summary: List[Dict[str, Any]] = []
    for position in ranking:
        index = int(position)
        contribution = float(contributions[index])
        summary.append(
            {
                'feature': feature_names[index] if index < known_names else f'f_{index}',
                'contribution': contribution,
                'abs_contribution': abs(contribution),
            }
        )
    return summary
def confidence_reasoning(confidence: float, severe_probability: float, top2_margin: float) -> Dict[str, Any]:
    """Map a prediction confidence to a band and a short handling rationale.

    Bands: ``'low'`` below 0.55, ``'medium'`` from 0.55 up to (not
    including) 0.75, and ``'high'`` at 0.75 or above.

    Args:
        confidence: Confidence of the predicted class.
        severe_probability: Probability assigned to the severe class.
        top2_margin: Gap between the two highest class probabilities.

    Returns:
        A dict with ``confidence_band``, ``rationale``, and the inputs
        ``severe_probability`` and ``top2_margin`` echoed back as floats.
    """
    # Compare with >= so that a NaN confidence, which fails every
    # comparison, falls through to the conservative 'low' band instead of
    # being reported as 'high'.
    if confidence >= 0.75:
        band = 'high'
        rationale = 'Prediction confidence is high; still validate against patient-specific contraindications.'
    elif confidence >= 0.55:
        band = 'medium'
        rationale = 'Prediction confidence is moderate; use as advisory signal with monitoring.'
    else:
        band = 'low'
        rationale = 'Prediction confidence is low; escalate to expert review before clinical action.'

    # An elevated severe-class probability combined with a narrow margin
    # adds a caution to the rationale regardless of the band.
    if severe_probability >= 0.45 and top2_margin <= 0.15:
        rationale += ' Severe interaction probability is elevated with narrow class margin, suggesting conservative handling.'

    return {
        'confidence_band': band,
        'rationale': rationale,
        'severe_probability': float(severe_probability),
        'top2_margin': float(top2_margin),
    }
def interaction_rationale_summary(
    predicted_label: str,
    confidence: float,
    top_features: List[Dict[str, Any]],
    severe_probability: float,
    top2_margin: float,
) -> str:
    """Compose a short human-readable rationale for one prediction.

    Args:
        predicted_label: Label of the predicted interaction class.
        confidence: Confidence of the prediction, rendered with 3 decimals.
        top_features: Ranked feature entries, each with a ``'feature'`` key.
            Only the first three are named in the text.
        severe_probability: Forwarded to ``confidence_reasoning``.
        top2_margin: Forwarded to ``confidence_reasoning``.

    Returns:
        Three sentences: the prediction, the leading features (or a
        placeholder when ``top_features`` is empty), and the rationale text
        produced by ``confidence_reasoning``.
    """
    if top_features:
        leading = ', '.join(str(entry['feature']) for entry in top_features[:3])
    else:
        leading = 'no dominant features'

    guidance = confidence_reasoning(confidence, severe_probability, top2_margin)['rationale']

    sentences = (
        f'Predicted {predicted_label} interaction with confidence {confidence:.3f}. ',
        f'Key contributing feature groups: {leading}. ',
        f'{guidance}',
    )
    return ''.join(sentences)
def explain_with_shap(model_path: Path, X_sample: np.ndarray, feature_names: List[str], out_path: Path) -> Dict[str, Any]:
    """Build per-sample SHAP reports for a saved model and write them as JSON.

    The model is loaded with ``joblib.load``, explained with
    ``shap.Explainer`` using ``X_sample`` both as background data and as the
    rows to explain, and scored with ``model.predict_proba``.

    NOTE(review): joblib deserialises via pickle, so ``model_path`` should
    only ever point at trusted model files -- confirm against callers.

    Args:
        model_path: Location of the joblib-serialised model.
        X_sample: Samples to explain; one report is produced per row.
        feature_names: Names passed through to ``summarize_top_features``.
        out_path: Destination JSON file; missing parent directories are
            created.

    Returns:
        The payload that was written: ``num_samples`` and ``reports``, where
        each report holds ``sample_index``, ``predicted_class`` (argmax index
        into the probability row), ``confidence`` (the maximum probability)
        and ``top_features``.

    Raises:
        RuntimeError: If the ``shap`` package could not be imported.
    """
    if shap is None:
        raise RuntimeError('shap is not installed; install shap to enable explainability reports')

    model = joblib.load(model_path)
    contributions = shap.Explainer(model, X_sample)(X_sample).values
    probabilities = model.predict_proba(X_sample)

    def build_report(row: int) -> Dict[str, Any]:
        """Assemble the report entry for a single sample row."""
        row_probs = probabilities[row]
        winner = int(np.argmax(row_probs))
        # 3-D SHAP values are indexed [sample, feature, class]: keep only the
        # contributions toward the predicted class. Otherwise there is a
        # single contribution vector per sample.
        if contributions.ndim == 3:
            per_feature = contributions[row, :, winner]
        else:
            per_feature = contributions[row]
        return {
            'sample_index': row,
            'predicted_class': winner,
            'confidence': float(np.max(row_probs)),
            'top_features': summarize_top_features(np.array(per_feature), feature_names),
        }

    sample_count = X_sample.shape[0]
    reports: List[Dict[str, Any]] = [build_report(row) for row in range(sample_count)]
    payload = {
        'num_samples': int(sample_count),
        'reports': reports,
    }

    out_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(payload, indent=2)
    out_path.write_text(serialized, encoding='utf-8')
    return payload