# flood-vulnerability / explainability.py
# (Uploaded by adema5051 — "Upload 10 files", commit a359779)
# explainability.py
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor
import shap
import pickle
import os
class VulnerabilityExplainer:
    """
    SHAP-based explainer for flood vulnerability scores.

    Trains a Random Forest surrogate model on existing vulnerability
    assessments, then uses a SHAP TreeExplainer to break a predicted
    score down into per-feature contributions.
    """

    def __init__(self, model_path='models/rf_explainer.pkl'):
        # Lazily populated by train() or load().
        self.model = None
        self.explainer = None
        self.model_path = model_path
        # Feature columns the surrogate model is trained on. Order matters:
        # SHAP values come back in this column order.
        self.feature_names = [
            'proximity_score',
            'tpi_score',
            'slope_score',
            'height_score',
            'elevation'
        ]

    def train(self, training_data_path='training_data.csv'):
        """
        Train a surrogate RF model on existing vulnerability assessments.

        Parameters
        ----------
        training_data_path : str
            CSV file containing every column in ``self.feature_names``
            plus a 'vulnerability_index' target column.

        Raises
        ------
        ValueError
            If any feature column or the target column is missing.
        """
        print(f"Loading training data from {training_data_path}...")
        df = pd.read_csv(training_data_path)

        # Validate the schema up front so failures are explicit.
        missing_cols = [col for col in self.feature_names if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing columns in training data: {missing_cols}")
        if 'vulnerability_index' not in df.columns:
            raise ValueError("Training data must have 'vulnerability_index' column")

        X = df[self.feature_names]
        y = df['vulnerability_index']

        print(f"Training Random Forest on {len(df)} samples...")
        self.model = RandomForestRegressor(
            n_estimators=100,
            max_depth=10,
            random_state=42,  # deterministic training for reproducible explanations
            n_jobs=-1
        )
        self.model.fit(X, y)

        print("Creating SHAP explainer...")
        self.explainer = shap.TreeExplainer(self.model)

        # BUG FIX: os.path.dirname() returns '' for a bare filename
        # (e.g. model_path='model.pkl'), and os.makedirs('') raises
        # FileNotFoundError even with exist_ok=True. Only create the
        # directory when there is one.
        model_dir = os.path.dirname(self.model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        with open(self.model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'explainer': self.explainer,
                'feature_names': self.feature_names
            }, f)

        # NOTE: R² is computed on the training set itself, so it measures
        # fit quality, not generalization.
        r2_score = self.model.score(X, y)
        print(f"✅ Model trained successfully!")
        print(f"   R² score: {r2_score:.3f}")
        print(f"   Saved to: {self.model_path}")

    def load(self):
        """Load a previously trained model; return True on success."""
        if os.path.exists(self.model_path):
            try:
                # SECURITY: pickle.load executes arbitrary code embedded in
                # the file — only load model files from trusted sources.
                with open(self.model_path, 'rb') as f:
                    data = pickle.load(f)
                self.model = data['model']
                self.explainer = data['explainer']
                self.feature_names = data['feature_names']
                print(f"✅ SHAP model loaded from {self.model_path}")
                return True
            except Exception as e:
                print(f"⚠️ Failed to load SHAP model: {e}")
                return False
        else:
            print(f"⚠️ SHAP model not found at {self.model_path}")
            return False

    def explain(self, features_dict):
        """
        Generate a SHAP explanation for a single assessment.

        Parameters
        ----------
        features_dict : dict
            Maps each name in ``self.feature_names`` to a numeric value.

        Returns
        -------
        dict or None
            Base/predicted vulnerability plus per-feature contributions
            sorted by absolute impact, or None if no model is available
            or a required feature is missing.
        """
        # BUG FIX: explicit None check — relying on the truthiness of a
        # third-party explainer object is fragile if it defines
        # __bool__ or __len__.
        if self.explainer is None:
            if not self.load():
                return None

        try:
            X = pd.DataFrame([features_dict])[self.feature_names]
        except KeyError as e:
            print(f"Missing feature in input: {e}")
            return None

        shap_values = self.explainer.shap_values(X)
        # Some SHAP versions return a list (one array per model output);
        # take the first for this single-output regressor.
        if isinstance(shap_values, list):
            shap_values = shap_values[0]
        shap_values = np.array(shap_values).astype(float).flatten()
        # expected_value may be a scalar or an array depending on the SHAP
        # version; mean() normalizes both to a single float.
        base_value = float(np.array(self.explainer.expected_value).mean())

        # Rank features by absolute impact on this prediction.
        contributions = list(zip(self.feature_names, shap_values))
        contributions.sort(key=lambda x: abs(x[1]), reverse=True)
        total_impact = sum(abs(v) for _, v in contributions)

        explanations = []
        for name, value in contributions:
            value = float(value)
            # Share of the total absolute SHAP mass; 0 when all values are 0.
            pct = (abs(value) / total_impact * 100) if total_impact > 0 else 0
            direction = "increases" if value > 0 else "decreases"
            explanations.append({
                'factor': self._humanize_feature(name),
                'contribution_pct': round(pct, 1),
                'direction': direction,
                'shap_value': round(value, 3)
            })

        return {
            'base_vulnerability': round(base_value, 3),
            # SHAP additivity: base value + sum of contributions = prediction.
            'predicted_vulnerability': round(base_value + sum(shap_values), 3),
            'explanations': explanations,
            'top_risk_driver': explanations[0]['factor'] if explanations else None
        }

    def _humanize_feature(self, feature_name):
        """Convert an internal feature name to a readable label (pass through unknowns)."""
        labels = {
            'proximity_score': 'Distance to water',
            'tpi_score': 'Topographic position (valley vs. ridge)',
            'slope_score': 'Terrain slope',
            'height_score': 'Building height and basement',
            'elevation': 'Elevation above sea level'
        }
        return labels.get(feature_name, feature_name)
if __name__ == "__main__":
    import sys

    # CLI entry point: train the surrogate SHAP explainer.
    # Usage: python explainability.py [training_data.csv]
    training_file = sys.argv[1] if len(sys.argv) > 1 else 'training_data.csv'

    # Bail out early with a clear message if the CSV is absent.
    if not os.path.exists(training_file):
        print(f"❌ Training data not found: {training_file}")
        sys.exit(1)

    explainer = VulnerabilityExplainer()
    explainer.train(training_file)
    print("\n✅ SHAP explainer ready!")