clarindasusan committed on
Commit
560ffc8
·
verified ·
1 Parent(s): e9989cc

Update src/train_model.py

Browse files
Files changed (1) hide show
  1. src/train_model.py +302 -357
src/train_model.py CHANGED
@@ -1,380 +1,325 @@
1
  """
2
- CYCLONE INTENSITY PREDICTION - XGBOOST MODEL
3
- Trains an XGBoost model to predict cyclone wind speed 24 hours ahead
4
- Includes hyperparameter tuning, feature importance, and comprehensive evaluation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  """
6
 
7
- import pandas as pd
8
  import numpy as np
9
- import xgboost as xgb
10
- from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
11
- from sklearn.model_selection import RandomizedSearchCV, KFold
12
- import matplotlib.pyplot as plt
13
- import seaborn as sns
14
- import joblib
15
  import os
16
- from datetime import datetime
 
 
17
 
18
- # Configuration
19
- INPUT_FILE = "data/preprocessed.csv"
20
- MODEL_OUTPUT = "models/xgboost_cyclone_model.pkl"
21
- RESULTS_DIR = "results/"
22
-
23
- # Create directories
24
- os.makedirs("models", exist_ok=True)
25
- os.makedirs(RESULTS_DIR, exist_ok=True)
26
-
27
- print("=" * 80)
28
- print("CYCLONE INTENSITY PREDICTION - XGBOOST TRAINING")
29
- print("=" * 80)
30
-
31
- # ============================================================================
32
- # STEP 1: LOAD PREPROCESSED DATA
33
- # ============================================================================
34
- print("\n📂 STEP 1: Loading preprocessed data...")
35
- print("-" * 80)
36
-
37
- df = pd.read_csv(INPUT_FILE)
38
- print(f"✅ Loaded dataset: {df.shape[0]} rows × {df.shape[1]} columns")
39
-
40
- # ============================================================================
41
- # STEP 2: PREPARE FEATURES AND TARGET
42
- # ============================================================================
43
- print("\n🎯 STEP 2: Preparing features and target...")
44
- print("-" * 80)
45
-
46
- # Define columns to exclude from features
47
- exclude_cols = [
48
- 'DATE_TIME', 'TARGET_24H', 'SPLIT', 'BASIN', 'TECH',
49
- 'CYCLONE_NUMBER', 'STORM_TYPE', # Already one-hot encoded
50
- 'YEAR' # Keep MONTH, HOUR, DOY but not YEAR (prevents overfitting)
51
- ]
52
-
53
- # Get feature columns
54
- feature_cols = [col for col in df.columns if col not in exclude_cols]
55
- target_col = 'TARGET_24H'
56
-
57
- print(f"📊 Total features: {len(feature_cols)}")
58
- print(f"🎯 Target variable: {target_col}")
59
-
60
- # Split by pre-defined split column
61
- train_df = df[df['SPLIT'] == 'train'].copy()
62
- test_df = df[df['SPLIT'] == 'test'].copy()
63
-
64
- X_train = train_df[feature_cols]
65
- y_train = train_df[target_col]
66
- X_test = test_df[feature_cols]
67
- y_test = test_df[target_col]
68
-
69
- print(f"\n✅ Training set: {X_train.shape[0]} samples")
70
- print(f"✅ Test set: {X_test.shape[0]} samples")
71
- print(f" Test set percentage: {len(test_df)/len(df)*100:.1f}%")
72
-
73
- # Check for any remaining NaN values
74
- if X_train.isnull().any().any():
75
- print("\n⚠️ Warning: NaN values found in training features")
76
- print(X_train.isnull().sum()[X_train.isnull().sum() > 0])
77
- print(" Filling NaN with 0...")
78
- X_train = X_train.fillna(0)
79
- X_test = X_test.fillna(0)
80
-
81
- # ============================================================================
82
- # STEP 3: BASELINE MODEL (DEFAULT PARAMETERS)
83
- # ============================================================================
84
- print("\n🤖 STEP 3: Training baseline XGBoost model...")
85
- print("-" * 80)
86
-
87
- baseline_model = xgb.XGBRegressor(
88
- n_estimators=100,
89
- max_depth=6,
90
- learning_rate=0.1,
91
- random_state=42,
92
- n_jobs=-1
93
  )
94
 
95
- baseline_model.fit(X_train, y_train)
96
- y_pred_baseline = baseline_model.predict(X_test)
97
-
98
- # Baseline metrics
99
- mae_baseline = mean_absolute_error(y_test, y_pred_baseline)
100
- rmse_baseline = np.sqrt(mean_squared_error(y_test, y_pred_baseline))
101
- r2_baseline = r2_score(y_test, y_pred_baseline)
102
 
103
- print(f"\n✅ BASELINE MODEL RESULTS:")
104
- print(f" MAE: {mae_baseline:.2f} knots")
105
- print(f" RMSE: {rmse_baseline:.2f} knots")
106
- print(f" R²: {r2_baseline:.4f}")
107
 
108
  # ============================================================================
109
- # STEP 4: HYPERPARAMETER TUNING
110
  # ============================================================================
111
- print("\n⚙️ STEP 4: Hyperparameter tuning with RandomizedSearchCV...")
112
- print("-" * 80)
113
- print(" This may take several minutes...")
114
-
115
- # Define hyperparameter search space
116
- param_distributions = {
117
- 'n_estimators': [100, 200, 300, 500],
118
- 'max_depth': [3, 5, 7, 9, 11],
119
- 'learning_rate': [0.01, 0.05, 0.1, 0.2],
120
- 'subsample': [0.6, 0.8, 1.0],
121
- 'colsample_bytree': [0.6, 0.8, 1.0],
122
- 'min_child_weight': [1, 3, 5],
123
- 'gamma': [0, 0.1, 0.2, 0.3],
124
- 'reg_alpha': [0, 0.1, 0.5, 1.0], # L1 regularization
125
- 'reg_lambda': [0.5, 1.0, 2.0] # L2 regularization
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  }
127
 
128
- # Create base model
129
- xgb_model = xgb.XGBRegressor(random_state=42, n_jobs=-1)
130
-
131
- # RandomizedSearchCV with cross-validation
132
- random_search = RandomizedSearchCV(
133
- xgb_model,
134
- param_distributions=param_distributions,
135
- n_iter=50, # Try 50 random combinations
136
- cv=5, # 5-fold cross-validation
137
- scoring='neg_mean_absolute_error',
138
- n_jobs=-1,
139
- random_state=42,
140
- verbose=1
141
- )
142
-
143
- random_search.fit(X_train, y_train)
144
-
145
- print(f"\n✅ Best hyperparameters found:")
146
- for param, value in random_search.best_params_.items():
147
- print(f" {param}: {value}")
148
-
149
- # Get best model
150
- best_model = random_search.best_estimator_
151
 
152
  # ============================================================================
153
- # STEP 5: EVALUATE BEST MODEL
154
  # ============================================================================
155
- print("\n📊 STEP 5: Evaluating best model...")
156
- print("-" * 80)
157
-
158
- # Predictions
159
- y_pred_train = best_model.predict(X_train)
160
- y_pred_test = best_model.predict(X_test)
161
-
162
- # Training metrics
163
- mae_train = mean_absolute_error(y_train, y_pred_train)
164
- rmse_train = np.sqrt(mean_squared_error(y_train, y_pred_train))
165
- r2_train = r2_score(y_train, y_pred_train)
166
-
167
- # Test metrics
168
- mae_test = mean_absolute_error(y_test, y_pred_test)
169
- rmse_test = np.sqrt(mean_squared_error(y_test, y_pred_test))
170
- r2_test = r2_score(y_test, y_pred_test)
171
-
172
- print(f"\n✅ OPTIMIZED MODEL RESULTS:")
173
- print(f"\n TRAINING SET:")
174
- print(f" MAE: {mae_train:.2f} knots")
175
- print(f" RMSE: {rmse_train:.2f} knots")
176
- print(f" R²: {r2_train:.4f}")
177
- print(f"\n TEST SET:")
178
- print(f" MAE: {mae_test:.2f} knots")
179
- print(f" RMSE: {rmse_test:.2f} knots")
180
- print(f" R²: {r2_test:.4f}")
181
-
182
- # Improvement over baseline
183
- print(f"\n 📈 IMPROVEMENT OVER BASELINE:")
184
- print(f" MAE improvement: {mae_baseline - mae_test:.2f} knots ({(mae_baseline-mae_test)/mae_baseline*100:.1f}%)")
185
- print(f" RMSE improvement: {rmse_baseline - rmse_test:.2f} knots ({(rmse_baseline-rmse_test)/rmse_baseline*100:.1f}%)")
186
 
187
- # ============================================================================
188
- # STEP 6: FEATURE IMPORTANCE
189
- # ============================================================================
190
- print("\n📊 STEP 6: Analyzing feature importance...")
191
- print("-" * 80)
192
-
193
- # Get feature importance
194
- feature_importance = pd.DataFrame({
195
- 'feature': feature_cols,
196
- 'importance': best_model.feature_importances_
197
- }).sort_values('importance', ascending=False)
198
-
199
- print(f"\n✅ TOP 20 MOST IMPORTANT FEATURES:")
200
- print(feature_importance.head(20).to_string(index=False))
201
-
202
- # Save full feature importance
203
- feature_importance.to_csv(f"{RESULTS_DIR}feature_importance.csv", index=False)
204
-
205
- # Plot feature importance
206
- plt.figure(figsize=(12, 8))
207
- top_features = feature_importance.head(20)
208
- plt.barh(range(len(top_features)), top_features['importance'])
209
- plt.yticks(range(len(top_features)), top_features['feature'])
210
- plt.xlabel('Importance Score')
211
- plt.title('Top 20 Feature Importance - XGBoost Model')
212
- plt.gca().invert_yaxis()
213
- plt.tight_layout()
214
- plt.savefig(f"{RESULTS_DIR}feature_importance.png", dpi=300, bbox_inches='tight')
215
- print(f"\n Saved: {RESULTS_DIR}feature_importance.png")
216
- plt.close()
217
-
218
- # ============================================================================
219
- # STEP 7: PREDICTION ANALYSIS
220
- # ============================================================================
221
- print("\n📈 STEP 7: Analyzing predictions...")
222
- print("-" * 80)
223
-
224
- # Create results dataframe
225
- results_df = pd.DataFrame({
226
- 'actual': y_test,
227
- 'predicted': y_pred_test,
228
- 'error': y_test - y_pred_test,
229
- 'abs_error': np.abs(y_test - y_pred_test)
230
- })
231
-
232
- # Error statistics by intensity ranges
233
- print(f"\n✅ ERROR ANALYSIS BY INTENSITY RANGE:")
234
- intensity_bins = [0, 34, 64, 96, 200]
235
- intensity_labels = ['Tropical Depression (<34kt)',
236
- 'Tropical Storm (34-63kt)',
237
- 'Hurricane/Cyclone (64-95kt)',
238
- 'Major Hurricane (>95kt)']
239
-
240
- results_df['intensity_category'] = pd.cut(results_df['actual'],
241
- bins=intensity_bins,
242
- labels=intensity_labels)
243
-
244
- for category in intensity_labels:
245
- category_data = results_df[results_df['intensity_category'] == category]
246
- if len(category_data) > 0:
247
- print(f"\n {category}:")
248
- print(f" Samples: {len(category_data)}")
249
- print(f" MAE: {category_data['abs_error'].mean():.2f} knots")
250
- print(f" Max Error: {category_data['abs_error'].max():.2f} knots")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
252
- # ============================================================================
253
- # STEP 8: VISUALIZATION
254
- # ============================================================================
255
- print("\n📊 STEP 8: Creating visualizations...")
256
- print("-" * 80)
257
-
258
- # 1. Actual vs Predicted scatter plot
259
- plt.figure(figsize=(10, 8))
260
- plt.scatter(y_test, y_pred_test, alpha=0.5, s=30)
261
- plt.plot([y_test.min(), y_test.max()],
262
- [y_test.min(), y_test.max()],
263
- 'r--', lw=2, label='Perfect Prediction')
264
- plt.xlabel('Actual Wind Speed (knots)', fontsize=12)
265
- plt.ylabel('Predicted Wind Speed (knots)', fontsize=12)
266
- plt.title(f'XGBoost: Actual vs Predicted\nMAE: {mae_test:.2f} kt, RMSE: {rmse_test:.2f} kt, R²: {r2_test:.3f}',
267
- fontsize=14)
268
- plt.legend()
269
- plt.grid(True, alpha=0.3)
270
- plt.tight_layout()
271
- plt.savefig(f"{RESULTS_DIR}actual_vs_predicted.png", dpi=300, bbox_inches='tight')
272
- print(f"✅ Saved: {RESULTS_DIR}actual_vs_predicted.png")
273
- plt.close()
274
-
275
- # 2. Error distribution
276
- plt.figure(figsize=(12, 5))
277
-
278
- plt.subplot(1, 2, 1)
279
- plt.hist(results_df['error'], bins=50, edgecolor='black', alpha=0.7)
280
- plt.xlabel('Prediction Error (knots)', fontsize=11)
281
- plt.ylabel('Frequency', fontsize=11)
282
- plt.title('Error Distribution', fontsize=12)
283
- plt.axvline(0, color='red', linestyle='--', linewidth=2, label='Zero Error')
284
- plt.legend()
285
- plt.grid(True, alpha=0.3)
286
-
287
- plt.subplot(1, 2, 2)
288
- plt.boxplot([results_df[results_df['intensity_category'] == cat]['abs_error'].dropna()
289
- for cat in intensity_labels],
290
- labels=['TD', 'TS', 'Hurricane', 'Major'])
291
- plt.ylabel('Absolute Error (knots)', fontsize=11)
292
- plt.xlabel('Storm Intensity Category', fontsize=11)
293
- plt.title('Error by Storm Intensity', fontsize=12)
294
- plt.xticks(rotation=45, ha='right')
295
- plt.grid(True, alpha=0.3, axis='y')
296
-
297
- plt.tight_layout()
298
- plt.savefig(f"{RESULTS_DIR}error_analysis.png", dpi=300, bbox_inches='tight')
299
- print(f"✅ Saved: {RESULTS_DIR}error_analysis.png")
300
- plt.close()
301
-
302
- # 3. Learning curve (training history)
303
- print(f"✅ Saved: {RESULTS_DIR}learning_curve.png")
304
 
305
  # ============================================================================
306
- # STEP 9: SAVE MODEL
307
  # ============================================================================
308
- print("\n💾 STEP 9: Saving trained model...")
309
- print("-" * 80)
310
-
311
- # Save the model
312
- joblib.dump(best_model, MODEL_OUTPUT)
313
- print(f"✅ Model saved: {MODEL_OUTPUT}")
314
-
315
- # Save feature names
316
- feature_names_file = "models/feature_names.txt"
317
- with open(feature_names_file, 'w') as f:
318
- for feature in feature_cols:
319
- f.write(f"{feature}\n")
320
- print(f"✅ Feature names saved: {feature_names_file}")
321
-
322
- # Save model metadata
323
- metadata = {
324
- 'training_date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
325
- 'n_features': len(feature_cols),
326
- 'n_train_samples': len(X_train),
327
- 'n_test_samples': len(X_test),
328
- 'test_mae': float(mae_test),
329
- 'test_rmse': float(rmse_test),
330
- 'test_r2': float(r2_test),
331
- 'best_params': random_search.best_params_
332
- }
333
 
334
- metadata_df = pd.DataFrame([metadata])
335
- metadata_df.to_csv(f"{RESULTS_DIR}model_metadata.csv", index=False)
336
- print(f"✅ Metadata saved: {RESULTS_DIR}model_metadata.csv")
337
-
338
- # ============================================================================
339
- # STEP 10: SUMMARY REPORT
340
- # ============================================================================
341
- print("\n" + "=" * 80)
342
- print("🎉 TRAINING COMPLETE!")
343
- print("=" * 80)
344
-
345
- print(f"""
346
- 📊 FINAL MODEL PERFORMANCE:
347
- Test MAE: {mae_test:.2f} knots
348
- • Test RMSE: {rmse_test:.2f} knots
349
- • Test R²: {r2_test:.4f}
350
-
351
- 📁 OUTPUT FILES:
352
- • Model: {MODEL_OUTPUT}
353
- • Feature names: {feature_names_file}
354
- • Feature importance: {RESULTS_DIR}feature_importance.csv
355
- • Visualizations: {RESULTS_DIR}*.png
356
- • Metadata: {RESULTS_DIR}model_metadata.csv
357
-
358
- 🎯 MODEL INTERPRETATION:
359
- • MAE of {mae_test:.2f} kt means on average, predictions are off by ~{mae_test:.0f} knots
360
- • For a 100kt cyclone, expect ±{mae_test:.0f} kt error range
361
- • R² of {r2_test:.4f} means model explains {r2_test*100:.1f}% of variance
362
-
363
- 🚀 NEXT STEPS:
364
- 1. Use this model for real-time predictions
365
- 2. Monitor performance on new cyclones
366
- 3. Retrain with more recent data periodically
367
- 4. Consider ensemble with other models (Random Forest, Neural Networks)
368
-
369
- 💡 USAGE:
370
- ```python
371
- import joblib
372
- model = joblib.load('{MODEL_OUTPUT}')
373
-
374
- # Make prediction for new cyclone data
375
- prediction = model.predict(new_cyclone_features)
376
- print(f"Predicted wind speed in 24h: {{prediction[0]:.1f}} knots")
377
- ```
378
- """)
379
-
380
- print("=" * 80)
 
1
  """
2
+ train_model.py
3
+ ==============
4
+ Trains FuzzyNeuralNetwork models for all four disaster types.
5
+
6
+ Usage:
7
+ python train_model.py # Train all
8
+ python train_model.py --disaster flood # Train one
9
+ python train_model.py --disaster flood --epochs 300
10
+
11
+ Synthetic Data Strategy:
12
+ Since real labeled training data is rarely available in a single format,
13
+ this script generates physically-motivated synthetic datasets.
14
+
15
+ Each dataset is constructed so that the ground-truth risk label follows
16
+ the domain logic (e.g., high rainfall + low elevation + poor drainage → flood risk).
17
+
18
+ When you have real data:
19
+ Replace the generate_*_data() functions with your own data loaders.
20
+ The rest of the training pipeline stays identical.
21
  """
22
 
23
+ import torch
24
  import numpy as np
 
 
 
 
 
 
25
  import os
26
+ import argparse
27
+ from sklearn.model_selection import train_test_split
28
+ from sklearn.metrics import roc_auc_score, mean_absolute_error
29
 
30
+ from src.fuzzy_neural_network import FuzzyNeuralNetwork, FNNTrainer, save_model
31
+ from src.disaster_predictors import (
32
+ FLOOD_FEATURES, CYCLONE_FEATURES, LANDSLIDE_FEATURES, EARTHQUAKE_FEATURES
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  )
34
 
35
+ MODEL_DIR = "models"
36
+ SEED = 42
37
+ np.random.seed(SEED)
38
+ torch.manual_seed(SEED)
 
 
 
39
 
 
 
 
 
40
 
41
  # ============================================================================
42
+ # SYNTHETIC DATA GENERATORS
43
  # ============================================================================
44
+ # Each function returns (X: np.ndarray, y: np.ndarray)
45
+ # X shape: (n_samples, n_features) — already normalized to [0, 1]
46
+ # y shape: (n_samples,) continuous risk score in [0, 1]
47
+
48
+ def generate_flood_data(n: int = 5000):
49
+ """
50
+ Flood risk is driven by:
51
+ - High rainfall
52
+ - Low elevation
53
+ - High soil saturation
54
+ - Low drainage capacity
55
+ - Close proximity to rivers
56
+ """
57
+ rng = np.random.default_rng(SEED)
58
+
59
+ rainfall_norm = rng.beta(2, 5, n) # Skewed: most days low rainfall
60
+ elevation_norm = rng.beta(3, 2, n) # Skewed: most areas higher ground
61
+ slope_norm = rng.beta(2, 5, n)
62
+ soil_sat_norm = rng.beta(2, 3, n)
63
+ dist_river_norm = rng.beta(2, 2, n)
64
+ drainage_norm = rng.beta(3, 2, n) # Most areas have decent drainage
65
+ hist_flood_norm = rng.beta(1.5, 3, n)
66
+ pop_density_norm = rng.beta(2, 2, n)
67
+
68
+ X = np.column_stack([
69
+ rainfall_norm, elevation_norm, slope_norm, soil_sat_norm,
70
+ dist_river_norm, drainage_norm, hist_flood_norm, pop_density_norm
71
+ ])
72
+
73
+ # Domain-informed risk formula
74
+ risk = (
75
+ 0.35 * rainfall_norm +
76
+ 0.25 * (1 - elevation_norm) + # Low elevation → higher risk
77
+ 0.15 * soil_sat_norm +
78
+ 0.10 * (1 - drainage_norm) + # Poor drainage → higher risk
79
+ 0.08 * (1 - dist_river_norm) + # Close to river → higher risk
80
+ 0.07 * hist_flood_norm
81
+ )
82
+
83
+ # Add noise and clip
84
+ risk += rng.normal(0, 0.05, n)
85
+ y = np.clip(risk, 0.0, 1.0).astype(np.float32)
86
+
87
+ return X.astype(np.float32), y
88
+
89
+
90
+ def generate_cyclone_data(n: int = 3000):
91
+ rng = np.random.default_rng(SEED + 1)
92
+
93
+ wind_norm = rng.beta(2, 5, n)
94
+ pressure_norm = rng.beta(3, 2, n) # Higher value = lower pressure = worse
95
+ sst_norm = rng.beta(3, 3, n)
96
+ curvature_norm = rng.beta(2, 3, n)
97
+ dist_coast_norm = rng.beta(2, 2, n)
98
+ surge_norm = rng.beta(2, 4, n)
99
+ moisture_norm = rng.beta(3, 3, n)
100
+ shear_norm = rng.beta(2, 3, n) # High shear weakens cyclones
101
+
102
+ X = np.column_stack([
103
+ wind_norm, pressure_norm, sst_norm, curvature_norm,
104
+ dist_coast_norm, surge_norm, moisture_norm, shear_norm
105
+ ])
106
+
107
+ risk = (
108
+ 0.30 * wind_norm +
109
+ 0.25 * (1 - pressure_norm) + # Low pressure = higher intensity
110
+ 0.15 * sst_norm + # Warm water feeds cyclones
111
+ 0.10 * surge_norm +
112
+ 0.10 * (1 - dist_coast_norm) +
113
+ 0.05 * moisture_norm +
114
+ 0.05 * (1 - shear_norm) # Low shear = stronger cyclone
115
+ )
116
+
117
+ risk += rng.normal(0, 0.05, n)
118
+ y = np.clip(risk, 0.0, 1.0).astype(np.float32)
119
+
120
+ return X.astype(np.float32), y
121
+
122
+
123
+ def generate_landslide_data(n: int = 4000):
124
+ rng = np.random.default_rng(SEED + 2)
125
+
126
+ slope_norm = rng.beta(2, 3, n)
127
+ rainfall_norm = rng.beta(2, 5, n)
128
+ soil_norm = rng.beta(3, 2, n) # Higher = more stable soil
129
+ veg_norm = rng.beta(3, 2, n) # Higher = more vegetation = more stable
130
+ seismic_norm = rng.beta(1.5, 4, n)
131
+ fault_norm = rng.beta(2, 2, n) # Higher = farther from fault
132
+ aspect_norm = rng.beta(2, 2, n)
133
+ hist_norm = rng.beta(1.5, 4, n)
134
+
135
+ X = np.column_stack([
136
+ slope_norm, rainfall_norm, soil_norm, veg_norm,
137
+ seismic_norm, fault_norm, aspect_norm, hist_norm
138
+ ])
139
+
140
+ risk = (
141
+ 0.30 * slope_norm +
142
+ 0.25 * rainfall_norm +
143
+ 0.15 * (1 - soil_norm) + # Unstable soil → higher risk
144
+ 0.10 * (1 - veg_norm) + # No vegetation → higher risk
145
+ 0.10 * seismic_norm +
146
+ 0.05 * (1 - fault_norm) + # Close to fault → higher risk
147
+ 0.05 * hist_norm
148
+ )
149
+
150
+ risk += rng.normal(0, 0.05, n)
151
+ y = np.clip(risk, 0.0, 1.0).astype(np.float32)
152
+
153
+ return X.astype(np.float32), y
154
+
155
+
156
+ def generate_earthquake_data(n: int = 3000):
157
+ rng = np.random.default_rng(SEED + 3)
158
+
159
+ hist_seism_norm = rng.beta(2, 4, n)
160
+ fault_norm = rng.beta(2, 2, n) # Higher = farther from fault
161
+ liquef_norm = rng.beta(2, 4, n)
162
+ depth_norm = rng.beta(3, 2, n) # Higher = deeper = less damage
163
+ stress_norm = rng.beta(2, 3, n)
164
+ vuln_norm = rng.beta(2, 3, n)
165
+ pop_norm = rng.beta(2, 2, n)
166
+ amp_norm = rng.beta(2, 3, n)
167
+
168
+ X = np.column_stack([
169
+ hist_seism_norm, fault_norm, liquef_norm, depth_norm,
170
+ stress_norm, vuln_norm, pop_norm, amp_norm
171
+ ])
172
+
173
+ risk = (
174
+ 0.25 * hist_seism_norm +
175
+ 0.20 * (1 - fault_norm) + # Close to fault = more risk
176
+ 0.15 * liquef_norm +
177
+ 0.10 * (1 - depth_norm) + # Shallow = more damage
178
+ 0.10 * stress_norm +
179
+ 0.10 * vuln_norm +
180
+ 0.05 * pop_norm +
181
+ 0.05 * amp_norm
182
+ )
183
+
184
+ risk += rng.normal(0, 0.05, n)
185
+ y = np.clip(risk, 0.0, 1.0).astype(np.float32)
186
+
187
+ return X.astype(np.float32), y
188
+
189
+
190
+ DATA_GENERATORS = {
191
+ "flood": (generate_flood_data, FLOOD_FEATURES),
192
+ "cyclone": (generate_cyclone_data, CYCLONE_FEATURES),
193
+ "landslide": (generate_landslide_data, LANDSLIDE_FEATURES),
194
+ "earthquake": (generate_earthquake_data, EARTHQUAKE_FEATURES),
195
  }
196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  # ============================================================================
199
+ # TRAINING PIPELINE
200
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+ def evaluate_model(model: FuzzyNeuralNetwork, X: torch.Tensor, y: torch.Tensor) -> dict:
203
+ model.eval()
204
+ with torch.no_grad():
205
+ preds = model(X).numpy()
206
+ y_np = y.numpy()
207
+
208
+ # Binarize at 0.5 for AUC
209
+ try:
210
+ auc = roc_auc_score((y_np > 0.5).astype(int), preds)
211
+ except Exception:
212
+ auc = float('nan')
213
+
214
+ mae = mean_absolute_error(y_np, preds)
215
+
216
+ return {
217
+ "MAE": round(float(mae), 4),
218
+ "AUC-ROC": round(float(auc), 4),
219
+ "Mean Prediction": round(float(preds.mean()), 4),
220
+ "Std Prediction": round(float(preds.std()), 4),
221
+ }
222
+
223
+
224
+ def train_disaster_model(disaster_type: str, epochs: int = 200, n_samples: int = None):
225
+ print(f"\n{'='*60}")
226
+ print(f" Training FNN for: {disaster_type.upper()}")
227
+ print(f"{'='*60}")
228
+
229
+ generator_fn, feature_names = DATA_GENERATORS[disaster_type]
230
+ n = n_samples or {"flood": 5000, "cyclone": 3000, "landslide": 4000, "earthquake": 3000}[disaster_type]
231
+
232
+ print(f"Generating {n} synthetic samples...")
233
+ X, y = generator_fn(n)
234
+
235
+ # Train/val/test split
236
+ X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size=0.15, random_state=SEED)
237
+ X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=0.15, random_state=SEED)
238
+
239
+ print(f" Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")
240
+
241
+ # Tensors
242
+ X_train_t = torch.tensor(X_train)
243
+ y_train_t = torch.tensor(y_train)
244
+ X_val_t = torch.tensor(X_val)
245
+ y_val_t = torch.tensor(y_val)
246
+ X_test_t = torch.tensor(X_test)
247
+ y_test_t = torch.tensor(y_test)
248
+
249
+ # Model
250
+ n_features = len(feature_names)
251
+ model = FuzzyNeuralNetwork(
252
+ n_features=n_features,
253
+ n_terms=3,
254
+ hidden_dims=[64, 32],
255
+ dropout=0.2
256
+ )
257
+
258
+ print(f" Model: FNN with {n_features} inputs, 3 fuzzy terms, 64→32 deep head")
259
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
260
+ print(f" Trainable parameters: {total_params:,}")
261
+
262
+ # Train
263
+ trainer = FNNTrainer(model, lr=1e-3, weight_decay=1e-4)
264
+ history = trainer.fit(
265
+ X_train_t, y_train_t,
266
+ X_val_t, y_val_t,
267
+ epochs=epochs, batch_size=64, patience=25
268
+ )
269
+
270
+ # Evaluate
271
+ print("\n Test set evaluation:")
272
+ metrics = evaluate_model(model, X_test_t, y_test_t)
273
+ for k, v in metrics.items():
274
+ print(f" {k}: {v}")
275
+
276
+ # Save
277
+ os.makedirs(MODEL_DIR, exist_ok=True)
278
+ model_path = os.path.join(MODEL_DIR, f"fnn_{disaster_type}_model.pt")
279
+ save_model(model, model_path, feature_names)
280
+
281
+ # Save feature names as text too
282
+ feat_path = os.path.join(MODEL_DIR, "feature_names", f"{disaster_type}_features.txt")
283
+ os.makedirs(os.path.dirname(feat_path), exist_ok=True)
284
+ with open(feat_path, "w") as f:
285
+ f.write("\n".join(feature_names))
286
+
287
+ print(f"\n Model saved to: {model_path}")
288
+ return metrics
289
+
290
+
291
+ def train_all(epochs: int = 200):
292
+ results = {}
293
+ for disaster_type in DATA_GENERATORS:
294
+ metrics = train_disaster_model(disaster_type, epochs=epochs)
295
+ results[disaster_type] = metrics
296
+
297
+ print("\n" + "="*60)
298
+ print(" TRAINING SUMMARY")
299
+ print("="*60)
300
+ for dt, metrics in results.items():
301
+ print(f" {dt.upper():12s} | MAE: {metrics['MAE']:.4f} | AUC: {metrics['AUC-ROC']:.4f}")
302
+ print("="*60)
303
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
  # ============================================================================
306
+ # ENTRY POINT
307
  # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
308
 
309
+ if __name__ == "__main__":
310
+ parser = argparse.ArgumentParser(description="Train FNN disaster models")
311
+ parser.add_argument(
312
+ "--disaster",
313
+ choices=list(DATA_GENERATORS.keys()) + ["all"],
314
+ default="all",
315
+ help="Which disaster model to train"
316
+ )
317
+ parser.add_argument("--epochs", type=int, default=200)
318
+ parser.add_argument("--samples", type=int, default=None)
319
+
320
+ args = parser.parse_args()
321
+
322
+ if args.disaster == "all":
323
+ train_all(epochs=args.epochs)
324
+ else:
325
+ train_disaster_model(args.disaster, epochs=args.epochs, n_samples=args.samples)