# cyclone-pred-api / src/train_model.py
# (Hugging Face page header removed — uploaded by clarindasusan,
#  commit "Update src/train_model.py", ba7df03, verified)
"""
train_model.py
==============
Trains FuzzyNeuralNetwork models for all four disaster types.
Usage:
python train_model.py # Train all
python train_model.py --disaster flood # Train one
python train_model.py --disaster flood --epochs 300
Data Strategy:
Each generate_*_data() function loads real CSV datasets from the data/
directory and joins them onto an event "spine" (flood history, cyclone
tracks, landslide catalog, earthquake history) via nearest-neighbour
matching on latitude/longitude. The ground-truth risk label either comes
directly from the data (e.g., severity_score) or is derived from domain
logic (e.g., high wind speed + low central pressure + coastal proximity
→ cyclone risk).
To swap in different data:
Replace the generate_*_data() functions with your own data loaders.
The rest of the training pipeline stays identical.
"""
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
DATA_DIR = os.path.join(BASE_DIR, "data")
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import NearestNeighbors
import os
import argparse
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, mean_absolute_error
from scipy.spatial import cKDTree
from src.fuzzy_neural_network import FuzzyNeuralNetwork, FNNTrainer, save_model
from src.disaster_predictors import (
FLOOD_FEATURES, CYCLONE_FEATURES, LANDSLIDE_FEATURES, EARTHQUAKE_FEATURES
)
MODEL_DIR = "models"
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
# ============================================================================
# SYNTHETIC DATA GENERATORS
# ============================================================================
# Each function returns (X: np.ndarray, y: np.ndarray)
# X shape: (n_samples, n_features) — already normalized to [0, 1]
# y shape: (n_samples,) — continuous risk score in [0, 1]
def generate_flood_data(n: int = 5000):
    """Build the flood training set from real CSVs in DATA_DIR.

    Spine: flood_history_clean.csv (one row per flood event, labelled with
    severity_score). Rainfall and soil moisture are aggregated to monthly
    means and, along with drainage and elevation layers, nearest-neighbour
    joined on (latitude, longitude).

    Parameters
    ----------
    n : int, optional
        Unused — kept for signature parity with the other generators.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        X: float32, shape (rows, 7), min-max scaled features.
        y: float32, shape (rows,), severity_score min-max scaled to [0, 1].
    """
    rainfall = pd.read_csv(os.path.join(DATA_DIR, "rainfall_clean.csv"))
    flood_hist = pd.read_csv(os.path.join(DATA_DIR, "flood_history_clean.csv"))
    soil = pd.read_csv(os.path.join(DATA_DIR, "soil_moisture.csv"))
    drainage = pd.read_csv(os.path.join(DATA_DIR, "drainage_capacity.csv"))
    rivers = pd.read_csv(os.path.join(DATA_DIR, "river_network.csv"))
    elevation = pd.read_csv(os.path.join(DATA_DIR, "elevation.csv"))
    # ==============================
    # Prepare flood labels
    # ==============================
    # Ensure proper integer formatting before building a month-resolution date.
    flood_hist["year"] = flood_hist["year"].astype(int)
    flood_hist["month"] = flood_hist["month"].astype(int)
    flood_hist["date"] = pd.to_datetime(
        dict(
            year=flood_hist["year"],
            month=flood_hist["month"],
            day=1  # every event pinned to the 1st of its month
        )
    )
    rainfall["date"] = pd.to_datetime(rainfall["date"])
    soil["date"] = pd.to_datetime(soil["date"])
    # Aggregate rainfall & soil to monthly means per coordinate.
    rainfall_monthly = rainfall.groupby(
        [rainfall["date"].dt.to_period("M"), "latitude", "longitude"]
    )["rainfall_mm"].mean().reset_index()
    rainfall_monthly["date"] = rainfall_monthly["date"].dt.to_timestamp()
    soil_monthly = soil.groupby(
        [soil["date"].dt.to_period("M"), "latitude", "longitude"]
    )["soil_saturation_pct"].mean().reset_index()
    soil_monthly["date"] = soil_monthly["date"].dt.to_timestamp()

    # ==============================
    # Spatial Nearest Join Function
    # ==============================
    def spatial_join(base_df, join_df, features):
        # Attach `features` from join_df's nearest (lat, lon) row to each base_df row.
        nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree')
        nbrs.fit(join_df[["latitude", "longitude"]])
        distances, indices = nbrs.kneighbors(base_df[["latitude", "longitude"]])
        joined = join_df.iloc[indices.flatten()][features].reset_index(drop=True)
        return joined

    # ==============================
    # Align all features to flood history
    # ==============================
    # NOTE(review): these joins match by location only — the event month is
    # ignored, so a flood row may pick up rainfall/soil statistics from a
    # different month at the same site. Confirm this is intended.
    base = flood_hist.copy()
    # Rainfall
    rain_features = spatial_join(base, rainfall_monthly,
                                 ["rainfall_mm"])
    base["rainfall_mm"] = rain_features["rainfall_mm"]
    # Soil moisture
    soil_features = spatial_join(base, soil_monthly,
                                 ["soil_saturation_pct"])
    base["soil_saturation_pct"] = soil_features["soil_saturation_pct"]
    # Drainage
    drainage_features = spatial_join(base, drainage,
                                     ["drainage_capacity_index"])
    base["drainage_capacity_index"] = drainage_features["drainage_capacity_index"]
    # Elevation (plus hydrology layers sampled at the same points)
    elevation_features = spatial_join(base, elevation,
                                      ["elevation_m",
                                       "flow_accumulation", "twi"])
    for col in elevation_features.columns:
        base[col] = elevation_features[col]
    # Distance to nearest river.
    # NOTE(review): this distance is in raw lat/lon degrees, not km; the
    # MinMax scaling below makes the unit irrelevant for training.
    river_coords = rivers[["latitude", "longitude"]]
    nbrs = NearestNeighbors(n_neighbors=1).fit(river_coords)
    dist, _ = nbrs.kneighbors(base[["latitude", "longitude"]])
    base["dist_river"] = dist
    # ==============================
    # Feature Selection
    # ==============================
    features = [
        "rainfall_mm",
        "elevation_m",
        "soil_saturation_pct",
        "dist_river",
        "drainage_capacity_index",
        "flow_accumulation",
        "twi",
    ]
    # Drop rows missing any feature or the label.
    base = base.dropna(subset=features + ["severity_score"])
    # ==============================
    # Normalize
    # ==============================
    scaler = MinMaxScaler()
    X = scaler.fit_transform(base[features])
    # Normalize target to [0, 1]
    y = MinMaxScaler().fit_transform(
        base[["severity_score"]]
    ).flatten()
    return X.astype(np.float32), y.astype(np.float32)
# ── Replace generate_cyclone_data() entirely ──────────────────────────────
def generate_cyclone_data(n: int = 3000):
    """
    Loads and joins the four cyclone CSVs:
    - cyclone_tracks_clean.csv (spine)
    - sea_surface_temp.csv (spatial join)
    - atmospheric_moisture.csv (spatial join)
    - wind_shear.csv (spatial join)
    Returns X (normalized), y (risk_score) matching CYCLONE_FEATURES order.

    The `n` parameter is unused — kept for signature parity with the
    other generators.
    """
    def nearest_merge(base_df, aux_df, cols):
        # Copy each requested column from aux_df's nearest (lat, lon) row.
        tree = cKDTree(aux_df[["latitude", "longitude"]].values)
        _, idxs = tree.query(base_df[["latitude", "longitude"]].values)
        base_df = base_df.copy()
        for col in cols:
            base_df[col] = aux_df[col].iloc[idxs].values
        return base_df
    # ── Load ──────────────────────────────────────────────────────────────
    tracks = pd.read_csv(os.path.join(DATA_DIR, "cyclone_tracks_clean.csv"))
    sst = pd.read_csv(os.path.join(DATA_DIR, "sea_surface_temp.csv"))
    moist = pd.read_csv(os.path.join(DATA_DIR, "atmospheric_moisture.csv"))
    shear = pd.read_csv(os.path.join(DATA_DIR, "wind_shear.csv"))
    # Normalise column names (lowercase, stripped) so lookups are robust.
    for df in (tracks, sst, moist, shear):
        df.columns = df.columns.str.lower().str.strip()
    base = tracks.copy()
    # ── Spatial joins ─────────────────────────────────────────────────────
    base = nearest_merge(base, sst, ["sea_surface_temp_c"])
    base = nearest_merge(base, moist, ["atmospheric_moisture"])
    base = nearest_merge(base, shear, ["shear_index"])
    # ── Validate required columns present ─────────────────────────────────
    required = [
        "wind_speed_kmh", "central_pressure_hpa", "sea_surface_temp_c",
        "track_curvature", "distance_to_coast_km", "storm_surge_potential",
        "atmospheric_moisture", "shear_index",
    ]
    missing = [c for c in required if c not in base.columns]
    if missing:
        raise ValueError(
            f"Cyclone data missing columns after join: {missing}\n"
            f"Available columns: {list(base.columns)}"
        )
    base = base.dropna(subset=required)
    # ── Build risk label ───────────────────────────────────────────────────
    # Use severity_score if it exists in tracks, otherwise derive it
    # from normalized, physically-motivated components.
    if "severity_score" in base.columns and base["severity_score"].notna().sum() > 0:
        base["risk_score"] = MinMaxScaler().fit_transform(
            base[["severity_score"]]
        ).flatten()
    else:
        wind_norm = np.clip(base["wind_speed_kmh"] / 350.0, 0, 1)
        # 1013 hPa ≈ standard sea-level pressure; 870 hPa ≈ lowest recorded
        pressure_norm = np.clip(
            (1013 - base["central_pressure_hpa"]) / (1013 - 870), 0, 1
        )
        coast_norm = np.clip(1 - base["distance_to_coast_km"] / 500.0, 0, 1)
        surge_norm = np.clip(base["storm_surge_potential"], 0, 1)
        sst_bonus = np.clip((base["sea_surface_temp_c"] - 26) / 9, 0, 1)
        shear_penalty = np.clip(base["shear_index"], 0, 1)
        base["risk_score"] = np.clip(
            0.30 * wind_norm +
            0.25 * pressure_norm +
            0.20 * coast_norm +
            0.15 * surge_norm +
            0.10 * sst_bonus -
            0.10 * shear_penalty +  # high shear weakens cyclones
            np.random.normal(0, 0.02, len(base)),  # small label noise
            0.0, 1.0
        )
    # ── Normalise features (mirrors FEATURE_RANGES in disaster_predictors) ─
    from src.disaster_predictors import FEATURE_RANGES
    X = np.zeros((len(base), len(required)), dtype=np.float32)
    for i, feat in enumerate(required):
        lo, hi = FEATURE_RANGES[feat]
        # 1e-8 guards against a zero-width range (lo == hi)
        X[:, i] = np.clip(
            (base[feat].values - lo) / (hi - lo + 1e-8), 0.0, 1.0
        )
    y = base["risk_score"].values.astype(np.float32)
    return X, y
def generate_landslide_data(n: int = 4000):
    """Build the landslide training set from the Global Landslide Catalog.

    Spine: Global_Landslide_Catalog_Export_rows.csv. Vegetation, fault-line,
    elevation and rainfall layers are nearest-neighbour joined on
    (latitude, longitude); the risk label blends event size, fatalities,
    local event frequency, slope and vegetation loss.

    `n` is unused — kept for signature parity with the other generators.
    Returns (X, y): X float32 (rows, 8) normalized via FEATURE_RANGES in
    LANDSLIDE_FEATURES order; y float32 risk score in [0, 1].
    """
    print("[Landslide] Using REAL data loader")
    def nearest_merge(base_df, aux_df, cols,
                      base_lat="latitude", base_lon="longitude",
                      aux_lat="latitude", aux_lon="longitude"):
        # Attach `cols` from aux_df's nearest (lat, lon) row; fail loudly
        # if expected columns are missing.
        if aux_lat not in aux_df.columns or aux_lon not in aux_df.columns:
            raise ValueError(
                f"nearest_merge: aux_df missing lat/lon. Has: {list(aux_df.columns)}"
            )
        tree = cKDTree(aux_df[[aux_lat, aux_lon]].values)
        _, idxs = tree.query(base_df[[base_lat, base_lon]].values)
        base_df = base_df.copy()
        for col in cols:
            if col not in aux_df.columns:
                raise ValueError(
                    f"nearest_merge: '{col}' not in aux_df. Has: {list(aux_df.columns)}"
                )
            base_df[col] = aux_df[col].iloc[idxs].values
        return base_df
    # ── Load ──────────────────────────────────────────────────────────────
    print("[Landslide] Loading CSVs...")
    catalog = pd.read_csv(os.path.join(DATA_DIR, "Global_Landslide_Catalog_Export_rows.csv"))
    veg = pd.read_csv(os.path.join(DATA_DIR, "vegetation_ndvi_aggregated.csv"))
    faults = pd.read_csv(os.path.join(DATA_DIR, "fault_lines.csv"))
    elev = pd.read_csv(os.path.join(DATA_DIR, "elevation.csv"))
    rain = pd.read_csv(os.path.join(DATA_DIR, "rainfall_clean.csv"))
    # Normalise column names so downstream lookups are robust.
    for df in (catalog, veg, faults, elev, rain):
        df.columns = df.columns.str.lower().str.strip()
    print(f"[Landslide] Catalog: {len(catalog)} rows, cols: {list(catalog.columns)}")
    print(f"[Landslide] Veg cols: {list(veg.columns)}")
    print(f"[Landslide] Fault cols: {list(faults.columns)}")
    print(f"[Landslide] Elev cols: {list(elev.columns)}")
    print(f"[Landslide] Rain cols: {list(rain.columns)}")
    # ── Clean catalog spine ───────────────────────────────────────────────
    catalog = catalog.dropna(subset=["latitude", "longitude"])
    catalog["event_date"] = pd.to_datetime(
        catalog["event_date"], errors="coerce"
    )
    catalog = catalog.dropna(subset=["event_date"])
    print(f"[Landslide] After date clean: {len(catalog)} rows")
    # ── historical_landslide_freq ─────────────────────────────────────────
    # Count events per 0.5° × 0.5° grid cell, normalized by the busiest cell.
    catalog["lat_bin"] = (catalog["latitude"] / 0.5).round() * 0.5
    catalog["lon_bin"] = (catalog["longitude"] / 0.5).round() * 0.5
    freq_map = (
        catalog.groupby(["lat_bin", "lon_bin"])
        .size().reset_index(name="event_count")
    )
    catalog = catalog.merge(freq_map, on=["lat_bin", "lon_bin"], how="left")
    catalog["historical_landslide_freq"] = (
        catalog["event_count"] / catalog["event_count"].max()
    ).clip(0, 1)
    base = catalog.copy()
    # ── Vegetation ────────────────────────────────────────────────────────
    print("[Landslide] Merging vegetation...")
    # Derive cover % from NDVI (−1..1 range) when not supplied directly.
    if "vegetation_cover_pct" not in veg.columns and "ndvi" in veg.columns:
        veg["vegetation_cover_pct"] = ((veg["ndvi"] + 1) / 2 * 100).clip(0, 100)
    if "latitude" not in veg.columns or "longitude" not in veg.columns:
        # No coordinates → broadcast a single global mean value.
        mean_cover = float(veg["vegetation_cover_pct"].mean())
        print(f"[Landslide] Veg has no coordinates — broadcasting mean: {mean_cover:.1f}%")
        base["vegetation_cover_pct"] = mean_cover
    else:
        base = nearest_merge(base, veg, ["vegetation_cover_pct"])
    # ── Fault lines ───────────────────────────────────────────────────────
    print("[Landslide] Merging fault lines...")
    faults = faults.rename(columns={"seismic_hazard_index": "seismic_activity_index"})
    # NOTE(review): a KD-tree on radian coordinates returns Euclidean
    # (chord-like) angular distances; multiplying by 6371 (mean Earth
    # radius, km) only approximates great-circle distance — acceptable for
    # a coarse feature, not for precise geodesy.
    fault_tree = cKDTree(np.radians(faults[["latitude", "longitude"]].values))
    event_coords = np.radians(base[["latitude", "longitude"]].values)
    dists_rad, idxs = fault_tree.query(event_coords)
    base = base.copy()
    base["distance_to_fault_km"] = dists_rad * 6371
    base["seismic_activity_index"] = faults["seismic_activity_index"].iloc[idxs].values
    # ── Elevation → aspect_index ──────────────────────────────────────────
    print("[Landslide] Merging elevation...")
    if "aspect_index" not in elev.columns:
        if "aspect_degrees" in elev.columns:
            elev["aspect_index"] = (elev["aspect_degrees"] / 360.0).clip(0, 1)
        else:
            elev["aspect_index"] = 0.5  # neutral default when aspect is unknown
    base = nearest_merge(base, elev, ["slope_degrees","aspect_index"])
    # ── Rainfall ──────────────────────────────────────────────────────────
    print("[Landslide] Merging rainfall...")
    rain["date"] = pd.to_datetime(rain["date"], errors="coerce")
    # Mean rainfall per site, capped at 200, used as an intensity proxy.
    rain_agg = (
        rain.groupby(["latitude", "longitude"])["rainfall_mm"]
        .mean().reset_index()
        .rename(columns={"rainfall_mm": "rainfall_intensity_mmh"})
    )
    rain_agg["rainfall_intensity_mmh"] = rain_agg["rainfall_intensity_mmh"].clip(0, 200)
    base = nearest_merge(base, rain_agg, ["rainfall_intensity_mmh"])
    # ── soil_type_index proxy ─────────────────────────────────────────────
    # Heuristic: steeper slope, sparser vegetation and heavier rain imply
    # weaker soils (lower index).
    slope_norm = np.clip(base["slope_degrees"] / 90.0, 0, 1)
    veg_norm = np.clip(base["vegetation_cover_pct"] / 100.0, 0, 1)
    rain_norm = np.clip(base["rainfall_intensity_mmh"] / 200.0, 0, 1)
    base["soil_type_index"] = np.clip(  # ← FIX 1: now saved
        1.0 - (0.4 * slope_norm + 0.3 * (1 - veg_norm) + 0.3 * rain_norm), 0, 1
    )
    # ── Risk label ────────────────────────────────────────────────────────
    if "fatality_count" in base.columns:  # ← FIX 2: no base.get()
        base["fatality_count"] = pd.to_numeric(
            base["fatality_count"], errors="coerce"
        ).fillna(0)
    else:
        base["fatality_count"] = 0.0
    size_map = {"small": 0.2, "medium": 0.5, "large": 0.8,
                "very_large": 1.0, "unknown": 0.3}
    if "landslide_size" in base.columns:  # ← FIX 2: no base.get()
        base["size_score"] = (
            base["landslide_size"]
            .astype(str).str.lower().str.strip()
            .map(size_map).fillna(0.3)
        )
    else:
        base["size_score"] = 0.3
    max_fatal = base["fatality_count"].max()
    # log1p compresses the heavy-tailed fatality distribution.
    fatality_norm = (
        np.log1p(base["fatality_count"]) / np.log1p(max_fatal + 1)
    ).clip(0, 1).values
    base["risk_score"] = np.clip(
        0.35 * base["size_score"] +
        0.25 * base["historical_landslide_freq"] +
        0.20 * fatality_norm +
        0.15 * slope_norm +
        0.05 * (1 - veg_norm) +
        np.random.normal(0, 0.02, len(base)),  # small label noise
        0.0, 1.0
    )
    print(f"[Landslide] Risk score: mean={base['risk_score'].mean():.3f}, "
          f"std={base['risk_score'].std():.3f}, "
          f">0.5: {(base['risk_score'] > 0.5).sum()} rows")
    # ── Final feature matrix ──────────────────────────────────────────────
    features = [
        "slope_degrees","rainfall_intensity_mmh", "soil_type_index",
        "vegetation_cover_pct", "seismic_activity_index",
        "distance_to_fault_km", "aspect_index", "historical_landslide_freq",
    ]
    # Verify features match LANDSLIDE_FEATURES exactly
    assert features == list(LANDSLIDE_FEATURES), (
        f"Feature mismatch!\n train: {features}\n"
        f" predictor: {list(LANDSLIDE_FEATURES)}"
    )
    base = base.dropna(subset=features + ["risk_score"])
    print(f"[Landslide] Final training rows: {len(base)}")
    if len(base) < 50:
        raise ValueError(
            f"Only {len(base)} clean rows — check CSV paths and column names"
        )
    from src.disaster_predictors import FEATURE_RANGES
    X = np.zeros((len(base), len(features)), dtype=np.float32)
    for i, feat in enumerate(features):
        lo, hi = FEATURE_RANGES[feat]
        # 1e-8 guards against a zero-width range (lo == hi)
        X[:, i] = np.clip(
            (base[feat].values - lo) / (hi - lo + 1e-8), 0, 1
        )
    y = base["risk_score"].values.astype(np.float32)
    return X, y
def generate_earthquake_data(n: int = 3000):
    """
    Loads and joins earthquake datasets:
    - earthquake_history.csv (spine + historical_seismicity, focal_depth_km, tectonic_stress_index)
    - fault_lines_earthquake.csv (distance_to_fault_km, seismic_hazard_index)
    - soil_liquefaction.csv (soil_liquefaction_index)
    - vs30_bedrock.csv (bedrock_amplification)
    - building_vulnerability.csv (building_vulnerability)
    - population_earthquake.csv (population_density_norm)

    `n` is unused — kept for signature parity with the other generators.
    Returns (X, y): X float32 (rows, 8) normalized via FEATURE_RANGES in
    EARTHQUAKE_FEATURES order; y float32 risk score in [0, 1].
    """
    print("[Earthquake] Using REAL data loader")
    def nearest_merge(base_df, aux_df, cols,
                      base_lat="latitude", base_lon="longitude",
                      aux_lat="latitude", aux_lon="longitude"):
        # Attach `cols` from aux_df's nearest (lat, lon) row; fail loudly
        # if expected columns are missing.
        if aux_lat not in aux_df.columns or aux_lon not in aux_df.columns:
            raise ValueError(
                f"nearest_merge: aux_df missing lat/lon. Has: {list(aux_df.columns)}"
            )
        tree = cKDTree(aux_df[[aux_lat, aux_lon]].values)
        _, idxs = tree.query(base_df[[base_lat, base_lon]].values)
        base_df = base_df.copy()
        for col in cols:
            if col not in aux_df.columns:
                raise ValueError(
                    f"nearest_merge: '{col}' not in aux_df. "
                    f"Has: {list(aux_df.columns)}"
                )
            base_df[col] = aux_df[col].iloc[idxs].values
        return base_df
    # ── Load ──────────────────────────────────────────────────────────────
    print("[Earthquake] Loading CSVs...")
    history = pd.read_csv(os.path.join(DATA_DIR, "earthquake_history.csv"))
    faults = pd.read_csv(os.path.join(DATA_DIR, "fault_lines_earthquake.csv"))
    liquef = pd.read_csv(os.path.join(DATA_DIR, "soil_liquefaction.csv"))
    vs30 = pd.read_csv(os.path.join(DATA_DIR, "vs30_bedrock.csv"))
    bldg = pd.read_csv(os.path.join(DATA_DIR, "building_vulnerability.csv"))
    pop = pd.read_csv(os.path.join(DATA_DIR, "population_earthquake.csv"))
    # Normalise column names so downstream lookups are robust.
    for df in (history, faults, liquef, vs30, bldg, pop):
        df.columns = df.columns.str.lower().str.strip()
    print(f"[Earthquake] History: {len(history)} rows, cols: {list(history.columns)}")
    print(f"[Earthquake] Faults cols: {list(faults.columns)}")
    print(f"[Earthquake] Liquef cols: {list(liquef.columns)}")
    print(f"[Earthquake] VS30 cols: {list(vs30.columns)}")
    print(f"[Earthquake] Bldg cols: {list(bldg.columns)}")
    print(f"[Earthquake] Pop cols: {list(pop.columns)}")
    # ── Clean history spine ───────────────────────────────────────────────
    history = history.dropna(subset=["latitude", "longitude"])
    history["date"] = pd.to_datetime(history["date"], errors="coerce")
    history = history.dropna(subset=["date"])
    print(f"[Earthquake] After date clean: {len(history)} rows")
    if len(history) == 0:
        # Surface a sample of the raw dates to make the parse failure debuggable.
        raise ValueError(
            "earthquake_history has 0 rows after date parsing. "
            f"Sample raw dates: {pd.read_csv(os.path.join(DATA_DIR, 'earthquake_history.csv'))['date'].head().tolist()}"
        )
    base = history.copy()
    # ── Fault lines → distance_to_fault_km ───────────────────────────────
    # fault_lines_earthquake already has distance_to_fault_km as a column
    # but we still spatial-join to get the nearest fault's values
    print("[Earthquake] Merging fault lines...")
    base = nearest_merge(base, faults, ["distance_to_fault_km"])
    # ── Soil liquefaction ─────────────────────────────────────────────────
    print("[Earthquake] Merging soil liquefaction...")
    base = nearest_merge(base, liquef, ["soil_liquefaction_index"])
    # ── VS30 / bedrock amplification ──────────────────────────────────────
    print("[Earthquake] Merging VS30 bedrock...")
    base = nearest_merge(base, vs30, ["bedrock_amplification"])
    # ── Building vulnerability ────────────────────────────────────────────
    print("[Earthquake] Merging building vulnerability...")
    base = nearest_merge(base, bldg, ["building_vulnerability"])
    # ── Population density ────────────────────────────────────────────────
    print("[Earthquake] Merging population...")
    base = nearest_merge(base, pop, ["population_density_norm"])
    # ── Validate all required columns present ─────────────────────────────
    required = [
        "historical_seismicity", "distance_to_fault_km", "soil_liquefaction_index",
        "focal_depth_km", "tectonic_stress_index", "building_vulnerability",
        "population_density_norm", "bedrock_amplification",
    ]
    missing = [c for c in required if c not in base.columns]
    if missing:
        raise ValueError(
            f"Missing columns after all merges: {missing}\n"
            f"Available: {list(base.columns)}"
        )
    base = base.dropna(subset=required)
    print(f"[Earthquake] Rows after dropna: {len(base)}")
    if len(base) < 50:
        raise ValueError(
            f"Only {len(base)} clean rows — check CSV paths and column names"
        )
    # ── Risk label ────────────────────────────────────────────────────────
    # Use magnitude if available, otherwise derive from features
    if "magnitude" in base.columns:
        base["magnitude"] = pd.to_numeric(base["magnitude"], errors="coerce").fillna(0)
        mag_norm = np.clip((base["magnitude"] - 2.0) / 7.0, 0, 1)  # scale 2–9
    else:
        mag_norm = pd.Series(np.zeros(len(base)))
    depth_norm = np.clip(base["focal_depth_km"] / 700.0, 0, 1)
    fault_norm = np.clip(base["distance_to_fault_km"] / 200.0, 0, 1)
    liquef_norm = np.clip(base["soil_liquefaction_index"], 0, 1)
    vuln_norm = np.clip(base["building_vulnerability"], 0, 1)
    pop_norm = np.clip(base["population_density_norm"], 0, 1)
    amp_norm = np.clip(base["bedrock_amplification"], 0, 1)
    stress_norm = np.clip(base["tectonic_stress_index"], 0, 1)
    seism_norm = np.clip(base["historical_seismicity"], 0, 1)
    base["risk_score"] = np.clip(
        0.25 * mag_norm.values +
        0.20 * (1 - depth_norm) +  # shallow = more damage
        0.15 * (1 - fault_norm) +  # close to fault = more risk
        0.15 * liquef_norm +
        0.10 * vuln_norm +
        0.05 * pop_norm +
        0.05 * amp_norm +
        0.05 * seism_norm +
        np.random.normal(0, 0.02, len(base)),  # small label noise
        0.0, 1.0
    )
    print(f"[Earthquake] Risk score: mean={base['risk_score'].mean():.3f}, "
          f"std={base['risk_score'].std():.3f}, "
          f">0.5: {(base['risk_score'] > 0.5).sum()} rows")
    # ── Normalise features ────────────────────────────────────────────────
    features = [
        "historical_seismicity", "distance_to_fault_km", "soil_liquefaction_index",
        "focal_depth_km", "tectonic_stress_index", "building_vulnerability",
        "population_density_norm", "bedrock_amplification",
    ]
    # Add immediately before the assert
    print(f"[Earthquake] Columns available before assert: {list(base.columns)}")
    print(f"[Earthquake] Required: {required}")
    print(f"[Earthquake] Missing: {[c for c in required if c not in base.columns]}")
    assert features == list(EARTHQUAKE_FEATURES), (
        f"Feature mismatch!\n train: {features}\n"
        f" predictor: {list(EARTHQUAKE_FEATURES)}"
    )
    from src.disaster_predictors import FEATURE_RANGES
    X = np.zeros((len(base), len(features)), dtype=np.float32)
    for i, feat in enumerate(features):
        lo, hi = FEATURE_RANGES[feat]
        # 1e-8 guards against a zero-width range (lo == hi)
        X[:, i] = np.clip(
            (base[feat].values - lo) / (hi - lo + 1e-8), 0, 1
        )
    y = base["risk_score"].values.astype(np.float32)
    return X, y
# Registry mapping each disaster type to its (data loader, feature-name list).
# The column order each loader emits must match the corresponding predictor
# feature list (asserted inside the landslide/earthquake loaders).
DATA_GENERATORS = {
    "flood": (generate_flood_data, FLOOD_FEATURES),
    "cyclone": (generate_cyclone_data, CYCLONE_FEATURES),
    "landslide": (generate_landslide_data, LANDSLIDE_FEATURES),
    "earthquake": (generate_earthquake_data, EARTHQUAKE_FEATURES),
}
# ============================================================================
# TRAINING PIPELINE
# ============================================================================
def evaluate_model(model: FuzzyNeuralNetwork, X: torch.Tensor, y: torch.Tensor) -> dict:
    """Score a trained model on (X, y).

    Returns a dict with MAE, AUC-ROC (NaN when undefined), and summary
    statistics of predictions and labels, each rounded to 4 decimals.
    """
    model.eval()
    with torch.no_grad():
        preds = model(X).numpy().flatten()
    labels = y.numpy()
    # Binarize the continuous target at its median so AUC is computable.
    cutoff = float(np.median(labels))
    binary_labels = (labels > cutoff).astype(int)
    try:
        auc = float(roc_auc_score(binary_labels, preds))
    except ValueError as e:
        # Only one class after binarization → AUC undefined; keep a float.
        print(f" [Warning] AUC undefined: {e}")
        auc = float("nan")
    mae = float(mean_absolute_error(labels, preds))
    metrics = {
        "MAE": round(mae, 4),
        "AUC-ROC": float("nan") if np.isnan(auc) else round(auc, 4),
        "Mean Prediction": round(float(preds.mean()), 4),
        "Mean Label": round(float(labels.mean()), 4),
        "Std Prediction": round(float(preds.std()), 4),
    }
    return metrics
def train_disaster_model(disaster_type: str, epochs: int = 1000, n_samples: int = None) -> dict:
    """Load data, train, evaluate and save one FNN model.

    Parameters
    ----------
    disaster_type : str
        Key into DATA_GENERATORS: "flood", "cyclone", "landslide", "earthquake".
    epochs : int
        Maximum training epochs (FNNTrainer stops early via patience).
    n_samples : int, optional
        Sample-count override; the current real-data loaders ignore it.

    Returns
    -------
    dict
        Test-set metrics from evaluate_model.
    """
    print(f"\n{'='*60}")
    print(f" Training FNN for: {disaster_type.upper()}")
    print(f"{'='*60}")
    generator_fn, feature_names = DATA_GENERATORS[disaster_type]
    # Per-disaster default sample counts (relevant only to synthetic generators).
    n = n_samples or {
        "flood": 5000, "cyclone": 3000,
        "landslide": 4000, "earthquake": 3000
    }[disaster_type]
    # All four loaders currently read real CSVs, so the else-branch below is
    # effectively dead; kept for when synthetic generators are reintroduced.
    REAL_DATA_GENERATORS = {"flood", "cyclone", "landslide", "earthquake"}
    if disaster_type in REAL_DATA_GENERATORS:
        print(f"Loading real data for {disaster_type}...")
    else:
        print(f"Generating {n} synthetic samples...")
    X, y = generator_fn(n)
    print(f" Data loaded: X={X.shape}, y={y.shape}, "
          f"y_mean={y.mean():.3f}, y_std={y.std():.3f}")
    # Train/val/test split: 15% test, then 15% of the remainder for validation.
    X_trainval, X_test, y_trainval, y_test = train_test_split(
        X, y, test_size=0.15, random_state=SEED
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_trainval, y_trainval, test_size=0.15, random_state=SEED
    )
    print(f" Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")
    # Tensors (inputs are already float32 numpy arrays from the loaders).
    X_train_t = torch.tensor(X_train)
    y_train_t = torch.tensor(y_train)
    X_val_t = torch.tensor(X_val)
    y_val_t = torch.tensor(y_val)
    X_test_t = torch.tensor(X_test)
    y_test_t = torch.tensor(y_test)
    # Model
    n_features = len(feature_names)
    model = FuzzyNeuralNetwork(
        n_features=n_features,
        n_terms=3,
        hidden_dims=[64, 32],
        dropout=0.2
    )
    print(f" Model: FNN with {n_features} inputs, 3 fuzzy terms, 64→32 deep head")
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f" Trainable parameters: {total_params:,}")
    # Train with early stopping on the validation set.
    trainer = FNNTrainer(model, lr=1e-3, weight_decay=1e-4)
    trainer.fit(
        X_train_t, y_train_t,
        X_val_t, y_val_t,
        epochs=epochs, batch_size=64, patience=50
    )
    # Evaluate on the held-out test split.
    print("\n Test set evaluation:")
    metrics = evaluate_model(model, X_test_t, y_test_t)
    for k, v in metrics.items():
        print(f" {k}: {v}")
    # Save model weights plus a feature-name manifest for the predictor side.
    os.makedirs(MODEL_DIR, exist_ok=True)
    model_path = os.path.join(MODEL_DIR, f"fnn_{disaster_type}_model.pt")
    save_model(model, model_path, feature_names)
    feat_path = os.path.join(MODEL_DIR, "feature_names", f"{disaster_type}_features.txt")
    os.makedirs(os.path.dirname(feat_path), exist_ok=True)
    with open(feat_path, "w") as f:
        f.write("\n".join(feature_names))
    print(f"\n Model saved to: {model_path}")
    return metrics  # ← always returns, never None
def train_all(epochs: int = 200) -> dict:
    """Train models for every registered disaster type and print a summary.

    A failure in one disaster's pipeline is caught and recorded as NaN
    metrics so the remaining models still train.

    Parameters
    ----------
    epochs : int
        Maximum epochs passed to each train_disaster_model call.

    Returns
    -------
    dict
        disaster_type -> metrics dict (as produced by evaluate_model, or
        an all-NaN placeholder with the same keys on failure).
    """
    results = {}
    for disaster_type in DATA_GENERATORS:
        try:
            metrics = train_disaster_model(disaster_type, epochs=epochs)
            if metrics is None:
                raise RuntimeError("train_disaster_model returned None")
            results[disaster_type] = metrics
        except Exception as e:
            print(f"\n [ERROR] {disaster_type} training failed: {e}")
            import traceback
            traceback.print_exc()
            # Placeholder mirrors evaluate_model's keys — including
            # "Mean Label", which the previous fallback omitted — so every
            # entry in `results` has a uniform schema.
            results[disaster_type] = {
                "MAE": float("nan"),
                "AUC-ROC": float("nan"),
                "Mean Prediction": float("nan"),
                "Mean Label": float("nan"),
                "Std Prediction": float("nan"),
            }
    print("\n" + "="*60)
    print(" TRAINING SUMMARY")
    print("="*60)
    for dt, metrics in results.items():
        auc = metrics["AUC-ROC"]
        mae = metrics["MAE"]
        auc_str = f"{auc:.4f}" if isinstance(auc, float) and not np.isnan(auc) else "nan"
        mae_str = f"{mae:.4f}" if isinstance(mae, float) and not np.isnan(mae) else "nan"
        print(f" {dt.upper():12s} | MAE: {mae_str} | AUC: {auc_str}")
    print("="*60)
    # Previously the summary dict was built but never returned (implicit
    # None). Returning it is backward-compatible and lets callers inspect
    # per-disaster metrics programmatically.
    return results
# ============================================================================
# ENTRY POINT
# ============================================================================
if __name__ == "__main__":
    # CLI entry point: train one disaster model, or all four (the default).
    parser = argparse.ArgumentParser(description="Train FNN disaster models")
    parser.add_argument(
        "--disaster",
        choices=list(DATA_GENERATORS.keys()) + ["all"],
        default="all",
        help="Which disaster model to train"
    )
    # Maximum training epochs per model (early stopping may end sooner).
    parser.add_argument("--epochs", type=int, default=200)
    # Sample-count override; only meaningful for synthetic generators.
    parser.add_argument("--samples", type=int, default=None)
    args = parser.parse_args()
    if args.disaster == "all":
        train_all(epochs=args.epochs)
    else:
        train_disaster_model(args.disaster, epochs=args.epochs, n_samples=args.samples)