File size: 4,160 Bytes
dc99b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b136ae1
dc99b5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import sys
from pathlib import Path
import numpy as np
import shap
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestRegressor

# Make project-local packages importable when this file is executed directly:
# prepend the repository root (three levels above this file) to sys.path.
# NOTE(review): this only affects imports that occur after this line.
sys.path.insert(0, str(Path(__file__).parent.parent.parent))


# Human-readable names for the model's input features, in the exact column
# order the model was trained on (zip() pairs them positionally below).
# These look like RDKit molecular descriptor names — presumably computed by
# an upstream featurization step; confirm against the training pipeline.
FEATURE_NAMES = [
    "MolWt",
    "LogP", 
    "TPSA",
    "NumHDonors",
    "NumHAcceptors",
    "NumRotatableBonds",
    "RingCount"
]


class ModelInterpreter:
    """Provides model interpretability via SHAP and uncertainty quantification.

    Wraps a trained RandomForestRegressor and exposes:
      * global feature importances (from the fitted forest),
      * per-prediction SHAP attributions,
      * empirical prediction intervals from the individual trees,
      * a waterfall-style bar plot of feature contributions.
    """

    def __init__(self, model: RandomForestRegressor, X_train: np.ndarray):
        """
        Initialize the interpreter with a trained model and training data.

        Args:
            model: Trained RandomForestRegressor
            X_train: Training features; a subsample is retained as background data
        """
        self.model = model
        self.explainer = shap.TreeExplainer(model)
        # NOTE(review): this background sample is stored but never passed to
        # TreeExplainer above, so it has no effect on the SHAP values computed
        # here. Kept for backward compatibility (callers may read the
        # attribute); consider wiring it into the explainer or removing it.
        self.X_background = shap.sample(X_train, min(100, len(X_train)))

    def _base_value(self) -> float:
        """Return the explainer's expected value as a plain float.

        Some shap versions expose ``expected_value`` as a 1-element array even
        for single-output regressors; normalize either form to a scalar so
        ``float()`` never fails or emits a deprecation warning.
        """
        return float(np.ravel(self.explainer.expected_value)[0])

    def get_feature_importance(self) -> dict:
        """Get global feature importance from the random forest model.

        Returns:
            Mapping of feature name -> impurity-based importance (floats that
            sum to ~1.0 over all features, per sklearn's convention).
        """
        importance = self.model.feature_importances_
        return {
            name: float(imp)
            for name, imp in zip(FEATURE_NAMES, importance)
        }

    def explain_prediction(self, X: np.ndarray) -> dict:
        """
        Explain a single prediction using SHAP values.

        Args:
            X: Feature vector (1, n_features)

        Returns:
            dict with per-feature SHAP values, the explainer base value, and
            the model's point prediction.
        """
        shap_values = self.explainer.shap_values(X)
        return {
            "shap_values": {
                name: float(val)
                for name, val in zip(FEATURE_NAMES, shap_values[0])
            },
            "base_value": self._base_value(),
            "prediction": float(self.model.predict(X)[0]),
        }

    def get_prediction_interval(self, X: np.ndarray, confidence=0.95) -> tuple:
        """
        Estimate prediction interval using individual tree predictions.

        The interval is empirical: the (alpha/2, 1 - alpha/2) percentiles of
        the per-tree predictions, not a parametric (e.g. normal) interval.

        Args:
            X: Feature vector (1, n_features)
            confidence: Confidence level (default 0.95)

        Returns:
            (lower_bound, upper_bound, std_dev)
        """
        # One prediction per tree in the ensemble.
        tree_predictions = np.array([
            tree.predict(X)[0]
            for tree in self.model.estimators_
        ])

        std_pred = tree_predictions.std()

        # Percentile bounds for the requested two-sided confidence level.
        alpha = 1 - confidence
        lower = np.percentile(tree_predictions, alpha / 2 * 100)
        upper = np.percentile(tree_predictions, (1 - alpha / 2) * 100)

        return float(lower), float(upper), float(std_pred)

    def plot_shap_waterfall(self, X: np.ndarray, feature_values: dict) -> plt.Figure:
        """
        Create a SHAP waterfall plot showing feature contributions.

        Args:
            X: Feature vector (1, n_features)
            feature_values: Dict mapping feature names to display values; used
                for the y-axis labels. Features missing from the dict fall
                back to the raw value in ``X``.

        Returns:
            matplotlib Figure
        """
        shap_values = self.explainer.shap_values(X)
        base_value = self._base_value()

        fig, ax = plt.subplots(figsize=(10, 6))

        # Order features by absolute contribution, largest first.
        indices = np.argsort(np.abs(shap_values[0]))[::-1]

        y_pos = np.arange(len(FEATURE_NAMES))
        # Red for positive (pushes prediction up), blue for negative.
        colors = ['#ff0051' if val > 0 else '#008bfb' for val in shap_values[0][indices]]

        ax.barh(y_pos, shap_values[0][indices], color=colors)
        ax.set_yticks(y_pos)
        # Prefer caller-supplied display values; fall back to the raw feature
        # vector (the original implementation ignored feature_values entirely).
        ax.set_yticklabels([
            f"{FEATURE_NAMES[i]} = {feature_values.get(FEATURE_NAMES[i], X[0][i]):.2f}"
            for i in indices
        ])
        ax.set_xlabel('SHAP value (impact on prediction)')
        ax.set_title(f'Feature Contributions to Prediction\nBase value: {base_value:.3f}')
        ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)

        plt.tight_layout()
        return fig