File size: 5,375 Bytes
a359779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# explainability.py

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
import shap
import pickle
import os


class VulnerabilityExplainer:
    """
    SHAP-based explainer for flood vulnerability scores.

    Trains a surrogate RandomForestRegressor to mimic existing
    vulnerability assessments, then uses a shap.TreeExplainer over that
    model to attribute each prediction to its input features.
    """

    def __init__(self, model_path='models/rf_explainer.pkl'):
        # Model and explainer are lazily populated by train() or load().
        self.model = None
        self.explainer = None
        self.model_path = model_path
        # Feature order used for both training and explanation; must match
        # the column names in the training CSV.
        self.feature_names = [
            'proximity_score',
            'tpi_score',
            'slope_score',
            'height_score',
            'elevation',
        ]

    def train(self, training_data_path='training_data.csv'):
        """
        Train the surrogate RF model on existing vulnerability assessments.

        The training CSV must contain every column in ``self.feature_names``
        plus a ``vulnerability_index`` target column. The fitted model and
        SHAP explainer are pickled to ``self.model_path``.

        Raises:
            ValueError: if required feature or target columns are missing.
        """
        print(f"Loading training data from {training_data_path}...")
        df = pd.read_csv(training_data_path)

        missing_cols = [col for col in self.feature_names if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing columns in training data: {missing_cols}")

        if 'vulnerability_index' not in df.columns:
            raise ValueError("Training data must have 'vulnerability_index' column")

        X = df[self.feature_names]
        y = df['vulnerability_index']

        print(f"Training Random Forest on {len(df)} samples...")

        self.model = RandomForestRegressor(
            n_estimators=100,
            max_depth=10,
            random_state=42,
            n_jobs=-1
        )
        self.model.fit(X, y)

        print("Creating SHAP explainer...")
        self.explainer = shap.TreeExplainer(self.model)

        # os.path.dirname returns '' when model_path has no directory
        # component, and os.makedirs('') raises FileNotFoundError — only
        # create the directory when there is one.
        model_dir = os.path.dirname(self.model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        with open(self.model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'explainer': self.explainer,
                'feature_names': self.feature_names
            }, f)

        # NOTE: R² is computed on the training data itself, so it is an
        # optimistic fit measure, not a generalization estimate.
        r2_score = self.model.score(X, y)
        print(f"✅ Model trained successfully!")
        print(f"   R² score: {r2_score:.3f}")
        print(f"   Saved to: {self.model_path}")

    def load(self):
        """Load a previously trained model/explainer bundle from disk.

        Returns:
            bool: True on success, False if the file is missing or
            unpicklable (a warning is printed; no exception propagates).
        """
        if os.path.exists(self.model_path):
            try:
                # SECURITY: pickle.load executes arbitrary code — only load
                # model files from trusted sources.
                with open(self.model_path, 'rb') as f:
                    data = pickle.load(f)
                    self.model = data['model']
                    self.explainer = data['explainer']
                    self.feature_names = data['feature_names']
                print(f"✅ SHAP model loaded from {self.model_path}")
                return True
            except Exception as e:
                print(f"⚠️ Failed to load SHAP model: {e}")
                return False
        else:
            print(f"⚠️ SHAP model not found at {self.model_path}")
            return False

    def explain(self, features_dict):
        """
        Generate a SHAP explanation for a single assessment.

        Args:
            features_dict: mapping with one value per name in
                ``self.feature_names``.

        Returns:
            dict with base/predicted vulnerability, per-feature
            contributions sorted by absolute impact, and the top risk
            driver — or None if the model cannot be loaded or a feature
            is missing.
        """
        # Explicit None check: the lazy-init sentinel is None, and relying
        # on truthiness of a third-party explainer object is fragile.
        if self.explainer is None:
            if not self.load():
                return None

        try:
            X = pd.DataFrame([features_dict])[self.feature_names]
        except KeyError as e:
            print(f"Missing feature in input: {e}")
            return None

        shap_values = self.explainer.shap_values(X)
        # Multi-output explainers return a list; take the first output.
        if isinstance(shap_values, list):
            shap_values = shap_values[0]

        shap_values = np.array(shap_values).astype(float).flatten()
        # expected_value may be scalar or per-output array; mean collapses both.
        base_value = float(np.array(self.explainer.expected_value).mean())

        contributions = list(zip(self.feature_names, shap_values))
        contributions.sort(key=lambda x: abs(x[1]), reverse=True)
        total_impact = sum(abs(v) for _, v in contributions)

        explanations = []
        for name, value in contributions:
            value = float(value)
            # Share of total absolute impact; guard against all-zero SHAP values.
            pct = (abs(value) / total_impact * 100) if total_impact > 0 else 0
            direction = "increases" if value > 0 else "decreases"
            explanations.append({
                'factor': self._humanize_feature(name),
                'contribution_pct': round(pct, 1),
                'direction': direction,
                'shap_value': round(value, 3)
            })

        return {
            'base_vulnerability': round(base_value, 3),
            # SHAP additivity: base value plus all contributions reconstructs
            # the model's prediction.
            'predicted_vulnerability': round(base_value + sum(shap_values), 3),
            'explanations': explanations,
            'top_risk_driver': explanations[0]['factor'] if explanations else None
        }

    def _humanize_feature(self, feature_name):
        """Convert a raw feature name to a readable description.

        Unknown names are returned unchanged.
        """
        labels = {
            'proximity_score': 'Distance to water',
            'tpi_score': 'Topographic position (valley vs. ridge)',
            'slope_score': 'Terrain slope',
            'height_score': 'Building height and basement',
            'elevation': 'Elevation above sea level'
        }
        return labels.get(feature_name, feature_name)


if __name__ == "__main__":
    import sys

    # First CLI argument, when present, overrides the default CSV path.
    training_file = sys.argv[1] if len(sys.argv) > 1 else 'training_data.csv'

    if not os.path.exists(training_file):
        print(f"❌ Training data not found: {training_file}")
        sys.exit(1)

    explainer = VulnerabilityExplainer()
    explainer.train(training_file)

    print("\n✅ SHAP explainer ready!")