File size: 3,584 Bytes
d29b763
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""Explainability helpers for DDI models.

Supports SHAP explanations for sklearn-compatible models and rationale summaries.
"""
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List

import joblib
import numpy as np

try:
    import shap
except Exception:  # pragma: no cover
    shap = None  # type: ignore


def summarize_top_features(values: np.ndarray, feature_names: List[str], top_k: int = 8) -> List[Dict[str, Any]]:
    """Return the ``top_k`` entries of ``values`` with the largest magnitude.

    Args:
        values: One-dimensional vector of per-feature contributions
            (e.g. the SHAP values for a single sample).
        feature_names: Names aligned with ``values`` by position. Positions
            beyond the end of this list are labelled ``f_<index>``.
        top_k: Maximum number of entries to return. Values <= 0 yield an
            empty list; ``None`` returns every entry.

    Returns:
        A list of dicts with keys ``feature``, ``contribution`` (signed) and
        ``abs_contribution``, ordered by decreasing magnitude. Tied entries
        keep their positional order, and NaN entries sort last.
    """
    vals = np.asarray(values, dtype=float)
    limit = None if top_k is None else max(top_k, 0)
    # Stable sort on the negated magnitude rather than reversing an ascending
    # argsort. Reversing would move NaN (which argsort places last) to the
    # front and would list tied features in reverse index order.
    order = np.argsort(-np.abs(vals), kind='stable')[:limit]
    return [
        {
            'feature': feature_names[i] if i < len(feature_names) else f'f_{i}',
            'contribution': float(vals[i]),
            'abs_contribution': float(abs(vals[i])),
        }
        for i in order
    ]


def confidence_reasoning(confidence: float, severe_probability: float, top2_margin: float) -> Dict[str, Any]:
    """Map a prediction's confidence onto a qualitative band with rationale text.

    Bands are ``low`` below 0.55, ``medium`` in [0.55, 0.75), and ``high`` at
    or above 0.75. A non-finite confidence (NaN or +/-inf) is reported as
    ``low`` so that it is escalated rather than trusted. When
    ``severe_probability >= 0.45`` and ``top2_margin <= 0.15``, an extra
    caution sentence is appended to the rationale.

    Args:
        confidence: Confidence score for the predicted class.
        severe_probability: Probability of the severe-interaction outcome,
            as supplied by the caller.
        top2_margin: Presumably the gap between the two highest class
            scores; verify against callers.

    Returns:
        Dict with ``confidence_band``, ``rationale``, ``severe_probability``
        and ``top2_margin`` (the last two coerced to ``float``).
    """
    # NaN compares False against every threshold, so without the finiteness
    # guard a NaN confidence would fall through to the 'high' band.
    if not np.isfinite(confidence) or confidence < 0.55:
        band = 'low'
        rationale = 'Prediction confidence is low; escalate to expert review before clinical action.'
    elif confidence < 0.75:
        band = 'medium'
        rationale = 'Prediction confidence is moderate; use as advisory signal with monitoring.'
    else:
        band = 'high'
        rationale = 'Prediction confidence is high; still validate against patient-specific contraindications.'

    if severe_probability >= 0.45 and top2_margin <= 0.15:
        rationale += ' Severe interaction probability is elevated with narrow class margin, suggesting conservative handling.'

    return {
        'confidence_band': band,
        'rationale': rationale,
        'severe_probability': float(severe_probability),
        'top2_margin': float(top2_margin),
    }


def interaction_rationale_summary(
    predicted_label: str,
    confidence: float,
    top_features: List[Dict[str, Any]],
    severe_probability: float,
    top2_margin: float,
) -> str:
    """Build a short human-readable explanation for one interaction prediction.

    The text names the predicted label and confidence (three decimals), lists
    the ``feature`` of up to the first three ``top_features`` entries, and
    ends with the rationale produced by ``confidence_reasoning``.

    Args:
        predicted_label: Label inserted verbatim into the sentence.
        confidence: Confidence score, formatted with three decimals.
        top_features: Dicts carrying a ``feature`` key; only the first three
            are mentioned.
        severe_probability: Forwarded to ``confidence_reasoning``.
        top2_margin: Forwarded to ``confidence_reasoning``.

    Returns:
        The assembled summary string.
    """
    if top_features:
        leading = ', '.join(str(entry['feature']) for entry in top_features[:3])
    else:
        leading = 'no dominant features'

    reasoning = confidence_reasoning(confidence, severe_probability, top2_margin)['rationale']

    sentences = [
        f'Predicted {predicted_label} interaction with confidence {confidence:.3f}. ',
        f'Key contributing feature groups: {leading}. ',
        reasoning,
    ]
    return ''.join(sentences)


def explain_with_shap(model_path: Path, X_sample: np.ndarray, feature_names: List[str], out_path: Path) -> Dict[str, Any]:
    """Explain a persisted classifier with SHAP and write a JSON report.

    Loads the model from ``model_path`` with ``joblib``. It then builds a
    ``shap.Explainer``, passing ``X_sample`` both as the explainer's second
    (background/masker) argument and as the rows to explain. The strongest
    contributions for each row are summarized.

    Args:
        model_path: Path to a joblib-serialized model exposing ``predict_proba``.
        X_sample: Rows to explain; iterated by row index up to ``shape[0]``.
        feature_names: Column names forwarded to ``summarize_top_features``.
        out_path: Destination JSON file; missing parent directories are created.

    Returns:
        The payload written to ``out_path``. It holds ``num_samples`` and a
        list of per-row ``reports``. Each report has ``sample_index``,
        ``predicted_class`` (the argmax column of ``predict_proba``),
        ``confidence`` (that row's max probability) and ``top_features``.

    Raises:
        RuntimeError: If the optional ``shap`` dependency is unavailable.
    """
    if shap is None:
        raise RuntimeError('shap is not installed; install shap to enable explainability reports')

    # joblib.load unpickles the file, which can execute arbitrary code;
    # only point this at model artifacts from a trusted source.
    model = joblib.load(model_path)
    explainer = shap.Explainer(model, X_sample)
    explanation = explainer(X_sample)
    contributions = explanation.values
    probabilities = model.predict_proba(X_sample)

    reports: List[Dict[str, Any]] = []
    for row in range(X_sample.shape[0]):
        row_probs = probabilities[row]
        winner = int(np.argmax(row_probs))
        # A 3-D explanation carries a trailing per-output axis; keep only the
        # contributions toward the predicted class.
        # NOTE(review): for a 2-D (single-output) explanation the same vector
        # is reported whatever `winner` is. If that output is one class's
        # score (e.g. a binary positive class), confirm this is intended.
        vec = contributions[row, :, winner] if contributions.ndim == 3 else contributions[row]
        reports.append({
            'sample_index': row,
            'predicted_class': winner,
            'confidence': float(np.max(row_probs)),
            'top_features': summarize_top_features(np.array(vec), feature_names),
        })

    payload = {'num_samples': int(X_sample.shape[0]), 'reports': reports}

    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(payload, indent=2), encoding='utf-8')
    return payload