# explainability.py
import os
import pickle

import numpy as np
import pandas as pd
import shap
from sklearn.ensemble import RandomForestRegressor
class VulnerabilityExplainer:
    """SHAP-based explainer for flood vulnerability scores.

    A Random Forest surrogate model is fitted to existing vulnerability
    assessments (``train``), then wrapped in a ``shap.TreeExplainer`` so
    that individual predictions can be decomposed into per-feature
    contributions (``explain``).  The fitted model, explainer, and feature
    list are pickled to ``model_path`` and can be restored with ``load``.
    """

    def __init__(self, model_path='models/rf_explainer.pkl'):
        """Create an (untrained) explainer.

        Args:
            model_path: Where the pickled model/explainer bundle is saved
                by ``train`` and read by ``load``.
        """
        # Populated lazily by train() or load().
        self.model = None
        self.explainer = None
        self.model_path = model_path
        # Feature columns expected both in the training CSV and in the
        # dict passed to explain(). Order matters: it fixes the column
        # order fed to the model and the order of SHAP values.
        self.feature_names = [
            'proximity_score',
            'tpi_score',
            'slope_score',
            'height_score',
            'elevation'
        ]

    def train(self, training_data_path='training_data.csv'):
        """Train the surrogate RF model on existing vulnerability assessments.

        The CSV must contain every column in ``self.feature_names`` plus a
        ``vulnerability_index`` target column.

        Args:
            training_data_path: Path to the training CSV.

        Raises:
            ValueError: If required feature or target columns are missing.
        """
        print(f"Loading training data from {training_data_path}...")
        df = pd.read_csv(training_data_path)

        # Fail fast with a clear message rather than a KeyError later.
        missing_cols = [col for col in self.feature_names if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Missing columns in training data: {missing_cols}")
        if 'vulnerability_index' not in df.columns:
            raise ValueError("Training data must have 'vulnerability_index' column")

        X = df[self.feature_names]
        y = df['vulnerability_index']

        print(f"Training Random Forest on {len(df)} samples...")
        self.model = RandomForestRegressor(
            n_estimators=100,
            max_depth=10,
            random_state=42,   # deterministic fits across runs
            n_jobs=-1          # use all available cores
        )
        self.model.fit(X, y)

        print("Creating SHAP explainer...")
        self.explainer = shap.TreeExplainer(self.model)

        # BUG FIX: os.path.dirname() returns '' for a bare filename, and
        # os.makedirs('') raises FileNotFoundError. Only create the
        # directory when the path actually contains one.
        model_dir = os.path.dirname(self.model_path)
        if model_dir:
            os.makedirs(model_dir, exist_ok=True)
        with open(self.model_path, 'wb') as f:
            pickle.dump({
                'model': self.model,
                'explainer': self.explainer,
                'feature_names': self.feature_names
            }, f)

        # NOTE: this is R^2 on the *training* data — an optimistic,
        # sanity-check metric, not a validation score.
        r2_score = self.model.score(X, y)
        print(f"✅ Model trained successfully!")
        print(f"   R² score: {r2_score:.3f}")
        print(f"   Saved to: {self.model_path}")

    def load(self):
        """Load a previously trained model bundle from ``self.model_path``.

        Returns:
            True on success, False if the file is absent or unreadable.
        """
        if os.path.exists(self.model_path):
            try:
                # SECURITY NOTE: pickle.load executes arbitrary code from the
                # file — only load model files from trusted sources.
                with open(self.model_path, 'rb') as f:
                    data = pickle.load(f)
                self.model = data['model']
                self.explainer = data['explainer']
                self.feature_names = data['feature_names']
                print(f"✅ SHAP model loaded from {self.model_path}")
                return True
            except Exception as e:
                # Best-effort load: report and fall back to "not available".
                print(f"⚠️ Failed to load SHAP model: {e}")
                return False
        else:
            print(f"⚠️ SHAP model not found at {self.model_path}")
            return False

    def explain(self, features_dict):
        """Generate a SHAP explanation for a single assessment.

        Args:
            features_dict: Mapping with one value per name in
                ``self.feature_names``.

        Returns:
            Dict with base/predicted vulnerability, a list of per-factor
            explanations sorted by absolute impact, and the top risk
            driver — or None if no model is available or input is invalid.
        """
        # Lazy-load the model on first use.
        if not self.explainer:
            if not self.load():
                return None

        try:
            # Single-row frame in canonical column order.
            X = pd.DataFrame([features_dict])[self.feature_names]
        except KeyError as e:
            print(f"Missing feature in input: {e}")
            return None

        shap_values = self.explainer.shap_values(X)
        # Some SHAP versions return a per-output list; take the first output.
        if isinstance(shap_values, list):
            shap_values = shap_values[0]
        shap_values = np.array(shap_values).astype(float).flatten()
        # expected_value may be scalar or array depending on SHAP version;
        # .mean() collapses both to a single base value.
        base_value = float(np.array(self.explainer.expected_value).mean())

        # Rank features by magnitude of contribution.
        contributions = list(zip(self.feature_names, shap_values))
        contributions.sort(key=lambda x: abs(x[1]), reverse=True)
        total_impact = sum(abs(v) for _, v in contributions)

        explanations = []
        for name, value in contributions:
            value = float(value)
            # Share of total absolute impact; guard against all-zero SHAP.
            pct = (abs(value) / total_impact * 100) if total_impact > 0 else 0
            direction = "increases" if value > 0 else "decreases"
            explanations.append({
                'factor': self._humanize_feature(name),
                'contribution_pct': round(pct, 1),
                'direction': direction,
                'shap_value': round(value, 3)
            })

        return {
            'base_vulnerability': round(base_value, 3),
            # For tree SHAP, base + sum(contributions) equals the prediction.
            'predicted_vulnerability': round(base_value + sum(shap_values), 3),
            'explanations': explanations,
            'top_risk_driver': explanations[0]['factor'] if explanations else None
        }

    def _humanize_feature(self, feature_name):
        """Convert internal feature names to readable descriptions."""
        labels = {
            'proximity_score': 'Distance to water',
            'tpi_score': 'Topographic position (valley vs. ridge)',
            'slope_score': 'Terrain slope',
            'height_score': 'Building height and basement',
            'elevation': 'Elevation above sea level'
        }
        # Fall back to the raw name for unknown features.
        return labels.get(feature_name, feature_name)
if __name__ == "__main__":
    import sys

    # Training CSV path: first CLI argument, else the default file.
    training_file = sys.argv[1] if len(sys.argv) > 1 else 'training_data.csv'

    # Guard clause: bail out early when the data file is absent.
    if not os.path.exists(training_file):
        print(f"❌ Training data not found: {training_file}")
        sys.exit(1)

    explainer = VulnerabilityExplainer()
    explainer.train(training_file)
    print("\n✅ SHAP explainer ready!")