"""Explainability helpers for DDI models.

Supports SHAP explanations for sklearn-compatible models and rationale summaries.

The pure helpers (feature ranking, confidence banding, rationale text) need only
numpy; ``explain_with_shap`` additionally requires the optional ``shap`` and
``joblib`` packages and reports their absence at call time.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np

try:
    import joblib
except ImportError:  # pragma: no cover
    # Only explain_with_shap loads models; keep the numpy-only helpers importable without joblib.
    joblib = None  # type: ignore

try:
    import shap
except Exception:  # pragma: no cover
    # Any import failure (not only ImportError) disables SHAP support;
    # explain_with_shap raises a RuntimeError when it is called.
    shap = None  # type: ignore


def summarize_top_features(
    values: np.ndarray,
    feature_names: List[str],
    top_k: Optional[int] = 8,
) -> List[Dict[str, Any]]:
    """Return the ``top_k`` contributions with the largest magnitude.

    Args:
        values: 1-D vector of per-feature contributions (e.g. one sample's SHAP values).
        feature_names: Names aligned with ``values``; indices past its end are
            labelled ``f_<index>``.
        top_k: Maximum number of entries to return. Non-positive values return an
            empty list; ``None`` returns every feature.

    Returns:
        Dicts with ``feature``, signed ``contribution`` and ``abs_contribution``,
        ordered by descending absolute contribution.
    """
    values = np.asarray(values)
    if top_k is not None and top_k <= 0:
        # A negative slice bound would drop only the tail instead of returning nothing.
        return []
    idx = np.argsort(np.abs(values))[::-1][:top_k]
    return [
        {
            'feature': feature_names[i] if i < len(feature_names) else f'f_{i}',
            'contribution': float(values[i]),
            'abs_contribution': float(abs(values[i])),
        }
        for i in idx
    ]


def confidence_reasoning(confidence: float, severe_probability: float, top2_margin: float) -> Dict[str, Any]:
    """Map a prediction's confidence to a qualitative band and review guidance.

    Bands: ``low`` below 0.55 (or when ``confidence`` is NaN/inf), ``medium`` in
    [0.55, 0.75), ``high`` from 0.75. When ``severe_probability >= 0.45`` and
    ``top2_margin <= 0.15`` an extra caution sentence is appended.

    Args:
        confidence: Prediction confidence, presumably the top-class probability in [0, 1].
        severe_probability: Probability assigned to the severe interaction class.
        top2_margin: Gap between the two highest class probabilities.

    Returns:
        Dict with ``confidence_band``, ``rationale``, and the two probability inputs as floats.
    """
    if not np.isfinite(confidence) or confidence < 0.55:
        # Non-finite confidence fails every `<` test; without this check it would be reported as 'high'.
        band = 'low'
        rationale = 'Prediction confidence is low; escalate to expert review before clinical action.'
    elif confidence < 0.75:
        band = 'medium'
        rationale = 'Prediction confidence is moderate; use as advisory signal with monitoring.'
    else:
        band = 'high'
        rationale = 'Prediction confidence is high; still validate against patient-specific contraindications.'

    if severe_probability >= 0.45 and top2_margin <= 0.15:
        rationale += ' Severe interaction probability is elevated with narrow class margin, suggesting conservative handling.'

    return {
        'confidence_band': band,
        'rationale': rationale,
        'severe_probability': float(severe_probability),
        'top2_margin': float(top2_margin),
    }


def interaction_rationale_summary(
    predicted_label: str,
    confidence: float,
    top_features: List[Dict[str, Any]],
    severe_probability: float,
    top2_margin: float,
) -> str:
    """Build a short human-readable rationale for a predicted interaction.

    Args:
        predicted_label: Predicted interaction class, inserted verbatim.
        confidence: Prediction confidence; shown to three decimals and banded via
            ``confidence_reasoning``.
        top_features: Ranked feature dicts (e.g. from ``summarize_top_features``);
            the ``feature`` values of the first three are named.
        severe_probability: Forwarded to ``confidence_reasoning``.
        top2_margin: Forwarded to ``confidence_reasoning``.

    Returns:
        The summary sentence(s) followed by the confidence rationale.
    """
    strongest = ', '.join([str(x['feature']) for x in top_features[:3]]) if top_features else 'no dominant features'
    conf = confidence_reasoning(confidence, severe_probability, top2_margin)
    return (
        f'Predicted {predicted_label} interaction with confidence {confidence:.3f}. '
        f'Key contributing feature groups: {strongest}. '
        f'{conf["rationale"]}'
    )


def explain_with_shap(model_path: Path, X_sample: np.ndarray, feature_names: List[str], out_path: Path) -> Dict[str, Any]:
    """Write per-sample SHAP top-feature reports for a serialized model.

    Loads the model with joblib and explains every row of ``X_sample``, which is
    also passed to ``shap.Explainer`` as its background data. For each row it
    records the argmax class of ``model.predict_proba``, that class's
    probability, and the ``summarize_top_features`` ranking of the row's SHAP
    vector.

    Args:
        model_path: Path to a joblib-serialized model exposing ``predict_proba``.
        X_sample: Sample matrix whose first axis indexes the rows to explain.
        feature_names: Names for the SHAP feature axis.
        out_path: Destination JSON file; missing parent directories are created.

    Returns:
        The payload written to ``out_path``: ``num_samples`` and a ``reports`` list.

    Raises:
        RuntimeError: If ``shap`` or ``joblib`` is unavailable.
    """
    if shap is None:
        raise RuntimeError('shap is not installed; install shap to enable explainability reports')
    if joblib is None:
        raise RuntimeError('joblib is not installed; install joblib to load models for explainability reports')

    # joblib.load unpickles the file and can execute arbitrary code: only use trusted model artifacts.
    model = joblib.load(model_path)
    explainer = shap.Explainer(model, X_sample)
    shap_values = explainer(X_sample)

    raw_values = np.asarray(shap_values.values)
    probs = np.asarray(model.predict_proba(X_sample))
    pred_classes = np.argmax(probs, axis=1)
    n_samples = X_sample.shape[0]

    reports: List[Dict[str, Any]] = []
    for i in range(n_samples):
        pred_class = int(pred_classes[i])
        if raw_values.ndim == 3:
            # Multiclass output (sample, feature, class): explain the predicted class.
            vec = raw_values[i, :, pred_class]
        else:
            # Single-output explanation (sample, feature).
            # NOTE(review): for binary models this is presumably the positive-class
            # output even when class 0 is predicted — confirm the sign convention.
            vec = raw_values[i]
        reports.append(
            {
                'sample_index': i,
                'predicted_class': pred_class,
                'confidence': float(probs[i, pred_class]),
                'top_features': summarize_top_features(np.array(vec), feature_names),
            }
        )

    payload = {
        'num_samples': int(n_samples),
        'reports': reports,
    }
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, indent=2), encoding='utf-8')
    return payload