""" train_model.py ============== Trains FuzzyNeuralNetwork models for all four disaster types. Usage: python train_model.py # Train all python train_model.py --disaster flood # Train one python train_model.py --disaster flood --epochs 300 Synthetic Data Strategy: Since real labeled training data is rarely available in a single format, this script generates physically-motivated synthetic datasets. Each dataset is constructed so that the ground-truth risk label follows the domain logic (e.g., high rainfall + low elevation + poor drainage → flood risk). When you have real data: Replace the generate_*_data() functions with your own data loaders. The rest of the training pipeline stays identical. """ import os BASE_DIR = os.path.dirname(os.path.abspath(__file__)) DATA_DIR = os.path.join(BASE_DIR, "data") import torch import numpy as np import pandas as pd from sklearn.preprocessing import MinMaxScaler from sklearn.neighbors import NearestNeighbors import os import argparse from sklearn.model_selection import train_test_split from sklearn.metrics import roc_auc_score, mean_absolute_error from scipy.spatial import cKDTree from src.fuzzy_neural_network import FuzzyNeuralNetwork, FNNTrainer, save_model from src.disaster_predictors import ( FLOOD_FEATURES, CYCLONE_FEATURES, LANDSLIDE_FEATURES, EARTHQUAKE_FEATURES ) MODEL_DIR = "models" SEED = 42 np.random.seed(SEED) torch.manual_seed(SEED) # ============================================================================ # SYNTHETIC DATA GENERATORS # ============================================================================ # Each function returns (X: np.ndarray, y: np.ndarray) # X shape: (n_samples, n_features) — already normalized to [0, 1] # y shape: (n_samples,) — continuous risk score in [0, 1] def generate_flood_data(n: int = 5000): rainfall = pd.read_csv(os.path.join(DATA_DIR, "rainfall_clean.csv")) flood_hist = pd.read_csv(os.path.join(DATA_DIR, "flood_history_clean.csv")) soil = pd.read_csv(os.path.join(DATA_DIR, "soil_moisture.csv")) drainage = pd.read_csv(os.path.join(DATA_DIR, "drainage_capacity.csv")) rivers = pd.read_csv(os.path.join(DATA_DIR, "river_network.csv")) elevation = pd.read_csv(os.path.join(DATA_DIR, "elevation.csv")) # ============================== # Prepare flood labels # ============================== # Ensure proper integer formatting flood_hist["year"] = flood_hist["year"].astype(int) flood_hist["month"] = flood_hist["month"].astype(int) flood_hist["date"] = pd.to_datetime( dict( year=flood_hist["year"], month=flood_hist["month"], day=1 ) ) rainfall["date"] = pd.to_datetime(rainfall["date"]) soil["date"] = pd.to_datetime(soil["date"]) # Aggregate rainfall & soil monthly rainfall_monthly = rainfall.groupby( [rainfall["date"].dt.to_period("M"), "latitude", "longitude"] )["rainfall_mm"].mean().reset_index() rainfall_monthly["date"] = rainfall_monthly["date"].dt.to_timestamp() soil_monthly = soil.groupby( [soil["date"].dt.to_period("M"), "latitude", "longitude"] )["soil_saturation_pct"].mean().reset_index() soil_monthly["date"] = soil_monthly["date"].dt.to_timestamp() # ============================== # Spatial Nearest Join Function # ============================== def spatial_join(base_df, join_df, features): nbrs = NearestNeighbors(n_neighbors=1, algorithm='ball_tree') nbrs.fit(join_df[["latitude", "longitude"]]) distances, indices = nbrs.kneighbors(base_df[["latitude", "longitude"]]) joined = join_df.iloc[indices.flatten()][features].reset_index(drop=True) return joined # ============================== # Align all features to flood history # ============================== base = flood_hist.copy() # Rainfall rain_features = spatial_join(base, rainfall_monthly, ["rainfall_mm"]) base["rainfall_mm"] = rain_features["rainfall_mm"] # Soil moisture soil_features = spatial_join(base, soil_monthly, ["soil_saturation_pct"]) base["soil_saturation_pct"] = soil_features["soil_saturation_pct"] # Drainage drainage_features = spatial_join(base, drainage, ["drainage_capacity_index"]) base["drainage_capacity_index"] = drainage_features["drainage_capacity_index"] # Elevation elevation_features = spatial_join(base, elevation, ["elevation_m", "flow_accumulation", "twi"]) for col in elevation_features.columns: base[col] = elevation_features[col] # Distance to nearest river river_coords = rivers[["latitude", "longitude"]] nbrs = NearestNeighbors(n_neighbors=1).fit(river_coords) dist, _ = nbrs.kneighbors(base[["latitude", "longitude"]]) base["dist_river"] = dist # ============================== # Feature Selection # ============================== features = [ "rainfall_mm", "elevation_m", "soil_saturation_pct", "dist_river", "drainage_capacity_index", "flow_accumulation", "twi", ] base = base.dropna(subset=features + ["severity_score"]) # ============================== # Normalize # ============================== scaler = MinMaxScaler() X = scaler.fit_transform(base[features]) # Normalize target y = MinMaxScaler().fit_transform( base[["severity_score"]] ).flatten() return X.astype(np.float32), y.astype(np.float32) # ── Replace generate_cyclone_data() entirely ────────────────────────────── def generate_cyclone_data(n: int = 3000): """ Loads and joins the four cyclone CSVs: - cyclone_tracks_clean.csv (spine) - sea_surface_temp.csv (spatial join) - atmospheric_moisture.csv (spatial join) - wind_shear.csv (spatial join) Returns X (normalized), y (risk_score) matching CYCLONE_FEATURES order. """ def nearest_merge(base_df, aux_df, cols): tree = cKDTree(aux_df[["latitude", "longitude"]].values) _, idxs = tree.query(base_df[["latitude", "longitude"]].values) base_df = base_df.copy() for col in cols: base_df[col] = aux_df[col].iloc[idxs].values return base_df # ── Load ────────────────────────────────────────────────────────────── tracks = pd.read_csv(os.path.join(DATA_DIR, "cyclone_tracks_clean.csv")) sst = pd.read_csv(os.path.join(DATA_DIR, "sea_surface_temp.csv")) moist = pd.read_csv(os.path.join(DATA_DIR, "atmospheric_moisture.csv")) shear = pd.read_csv(os.path.join(DATA_DIR, "wind_shear.csv")) # Normalise column names for df in (tracks, sst, moist, shear): df.columns = df.columns.str.lower().str.strip() base = tracks.copy() # ── Spatial joins ───────────────────────────────────────────────────── base = nearest_merge(base, sst, ["sea_surface_temp_c"]) base = nearest_merge(base, moist, ["atmospheric_moisture"]) base = nearest_merge(base, shear, ["shear_index"]) # ── Validate required columns present ───────────────────────────────── required = [ "wind_speed_kmh", "central_pressure_hpa", "sea_surface_temp_c", "track_curvature", "distance_to_coast_km", "storm_surge_potential", "atmospheric_moisture", "shear_index", ] missing = [c for c in required if c not in base.columns] if missing: raise ValueError( f"Cyclone data missing columns after join: {missing}\n" f"Available columns: {list(base.columns)}" ) base = base.dropna(subset=required) # ── Build risk label ─────────────────────────────────────────────────── # Use severity_score if it exists in tracks, otherwise derive it if "severity_score" in base.columns and base["severity_score"].notna().sum() > 0: base["risk_score"] = MinMaxScaler().fit_transform( base[["severity_score"]] ).flatten() else: wind_norm = np.clip(base["wind_speed_kmh"] / 350.0, 0, 1) pressure_norm = np.clip( (1013 - base["central_pressure_hpa"]) / (1013 - 870), 0, 1 ) coast_norm = np.clip(1 - base["distance_to_coast_km"] / 500.0, 0, 1) surge_norm = np.clip(base["storm_surge_potential"], 0, 1) sst_bonus = np.clip((base["sea_surface_temp_c"] - 26) / 9, 0, 1) shear_penalty = np.clip(base["shear_index"], 0, 1) base["risk_score"] = np.clip( 0.30 * wind_norm + 0.25 * pressure_norm + 0.20 * coast_norm + 0.15 * surge_norm + 0.10 * sst_bonus - 0.10 * shear_penalty + # high shear weakens cyclones np.random.normal(0, 0.02, len(base)), 0.0, 1.0 ) # ── Normalise features (mirrors FEATURE_RANGES in disaster_predictors) ─ from src.disaster_predictors import FEATURE_RANGES X = np.zeros((len(base), len(required)), dtype=np.float32) for i, feat in enumerate(required): lo, hi = FEATURE_RANGES[feat] X[:, i] = np.clip( (base[feat].values - lo) / (hi - lo + 1e-8), 0.0, 1.0 ) y = base["risk_score"].values.astype(np.float32) return X, y def generate_landslide_data(n: int = 4000): print("[Landslide] Using REAL data loader") def nearest_merge(base_df, aux_df, cols, base_lat="latitude", base_lon="longitude", aux_lat="latitude", aux_lon="longitude"): if aux_lat not in aux_df.columns or aux_lon not in aux_df.columns: raise ValueError( f"nearest_merge: aux_df missing lat/lon. Has: {list(aux_df.columns)}" ) tree = cKDTree(aux_df[[aux_lat, aux_lon]].values) _, idxs = tree.query(base_df[[base_lat, base_lon]].values) base_df = base_df.copy() for col in cols: if col not in aux_df.columns: raise ValueError( f"nearest_merge: '{col}' not in aux_df. Has: {list(aux_df.columns)}" ) base_df[col] = aux_df[col].iloc[idxs].values return base_df # ── Load ────────────────────────────────────────────────────────────── print("[Landslide] Loading CSVs...") catalog = pd.read_csv(os.path.join(DATA_DIR, "Global_Landslide_Catalog_Export_rows.csv")) veg = pd.read_csv(os.path.join(DATA_DIR, "vegetation_ndvi_aggregated.csv")) faults = pd.read_csv(os.path.join(DATA_DIR, "fault_lines.csv")) elev = pd.read_csv(os.path.join(DATA_DIR, "elevation.csv")) rain = pd.read_csv(os.path.join(DATA_DIR, "rainfall_clean.csv")) for df in (catalog, veg, faults, elev, rain): df.columns = df.columns.str.lower().str.strip() print(f"[Landslide] Catalog: {len(catalog)} rows, cols: {list(catalog.columns)}") print(f"[Landslide] Veg cols: {list(veg.columns)}") print(f"[Landslide] Fault cols: {list(faults.columns)}") print(f"[Landslide] Elev cols: {list(elev.columns)}") print(f"[Landslide] Rain cols: {list(rain.columns)}") # ── Clean catalog spine ─────────────────────────────────────────────── catalog = catalog.dropna(subset=["latitude", "longitude"]) catalog["event_date"] = pd.to_datetime( catalog["event_date"], errors="coerce" ) catalog = catalog.dropna(subset=["event_date"]) print(f"[Landslide] After date clean: {len(catalog)} rows") # ── historical_landslide_freq ───────────────────────────────────────── catalog["lat_bin"] = (catalog["latitude"] / 0.5).round() * 0.5 catalog["lon_bin"] = (catalog["longitude"] / 0.5).round() * 0.5 freq_map = ( catalog.groupby(["lat_bin", "lon_bin"]) .size().reset_index(name="event_count") ) catalog = catalog.merge(freq_map, on=["lat_bin", "lon_bin"], how="left") catalog["historical_landslide_freq"] = ( catalog["event_count"] / catalog["event_count"].max() ).clip(0, 1) base = catalog.copy() # ── Vegetation ──────────────────────────────────────────────────────── print("[Landslide] Merging vegetation...") if "vegetation_cover_pct" not in veg.columns and "ndvi" in veg.columns: veg["vegetation_cover_pct"] = ((veg["ndvi"] + 1) / 2 * 100).clip(0, 100) if "latitude" not in veg.columns or "longitude" not in veg.columns: mean_cover = float(veg["vegetation_cover_pct"].mean()) print(f"[Landslide] Veg has no coordinates — broadcasting mean: {mean_cover:.1f}%") base["vegetation_cover_pct"] = mean_cover else: base = nearest_merge(base, veg, ["vegetation_cover_pct"]) # ── Fault lines ─────────────────────────────────────────────────────── print("[Landslide] Merging fault lines...") faults = faults.rename(columns={"seismic_hazard_index": "seismic_activity_index"}) fault_tree = cKDTree(np.radians(faults[["latitude", "longitude"]].values)) event_coords = np.radians(base[["latitude", "longitude"]].values) dists_rad, idxs = fault_tree.query(event_coords) base = base.copy() base["distance_to_fault_km"] = dists_rad * 6371 base["seismic_activity_index"] = faults["seismic_activity_index"].iloc[idxs].values # ── Elevation → aspect_index ────────────────────────────────────────── print("[Landslide] Merging elevation...") if "aspect_index" not in elev.columns: if "aspect_degrees" in elev.columns: elev["aspect_index"] = (elev["aspect_degrees"] / 360.0).clip(0, 1) else: elev["aspect_index"] = 0.5 base = nearest_merge(base, elev, ["slope_degrees","aspect_index"]) # ── Rainfall ────────────────────────────────────────────────────────── print("[Landslide] Merging rainfall...") rain["date"] = pd.to_datetime(rain["date"], errors="coerce") rain_agg = ( rain.groupby(["latitude", "longitude"])["rainfall_mm"] .mean().reset_index() .rename(columns={"rainfall_mm": "rainfall_intensity_mmh"}) ) rain_agg["rainfall_intensity_mmh"] = rain_agg["rainfall_intensity_mmh"].clip(0, 200) base = nearest_merge(base, rain_agg, ["rainfall_intensity_mmh"]) # ── soil_type_index proxy ───────────────────────────────────────────── slope_norm = np.clip(base["slope_degrees"] / 90.0, 0, 1) veg_norm = np.clip(base["vegetation_cover_pct"] / 100.0, 0, 1) rain_norm = np.clip(base["rainfall_intensity_mmh"] / 200.0, 0, 1) base["soil_type_index"] = np.clip( # ← FIX 1: now saved 1.0 - (0.4 * slope_norm + 0.3 * (1 - veg_norm) + 0.3 * rain_norm), 0, 1 ) # ── Risk label ──────────────────────────────────────────────────────── if "fatality_count" in base.columns: # ← FIX 2: no base.get() base["fatality_count"] = pd.to_numeric( base["fatality_count"], errors="coerce" ).fillna(0) else: base["fatality_count"] = 0.0 size_map = {"small": 0.2, "medium": 0.5, "large": 0.8, "very_large": 1.0, "unknown": 0.3} if "landslide_size" in base.columns: # ← FIX 2: no base.get() base["size_score"] = ( base["landslide_size"] .astype(str).str.lower().str.strip() .map(size_map).fillna(0.3) ) else: base["size_score"] = 0.3 max_fatal = base["fatality_count"].max() fatality_norm = ( np.log1p(base["fatality_count"]) / np.log1p(max_fatal + 1) ).clip(0, 1).values base["risk_score"] = np.clip( 0.35 * base["size_score"] + 0.25 * base["historical_landslide_freq"] + 0.20 * fatality_norm + 0.15 * slope_norm + 0.05 * (1 - veg_norm) + np.random.normal(0, 0.02, len(base)), 0.0, 1.0 ) print(f"[Landslide] Risk score: mean={base['risk_score'].mean():.3f}, " f"std={base['risk_score'].std():.3f}, " f">0.5: {(base['risk_score'] > 0.5).sum()} rows") # ── Final feature matrix ────────────────────────────────────────────── features = [ "slope_degrees","rainfall_intensity_mmh", "soil_type_index", "vegetation_cover_pct", "seismic_activity_index", "distance_to_fault_km", "aspect_index", "historical_landslide_freq", ] # Verify features match LANDSLIDE_FEATURES exactly assert features == list(LANDSLIDE_FEATURES), ( f"Feature mismatch!\n train: {features}\n" f" predictor: {list(LANDSLIDE_FEATURES)}" ) base = base.dropna(subset=features + ["risk_score"]) print(f"[Landslide] Final training rows: {len(base)}") if len(base) < 50: raise ValueError( f"Only {len(base)} clean rows — check CSV paths and column names" ) from src.disaster_predictors import FEATURE_RANGES X = np.zeros((len(base), len(features)), dtype=np.float32) for i, feat in enumerate(features): lo, hi = FEATURE_RANGES[feat] X[:, i] = np.clip( (base[feat].values - lo) / (hi - lo + 1e-8), 0, 1 ) y = base["risk_score"].values.astype(np.float32) return X, y def generate_earthquake_data(n: int = 3000): """ Loads and joins earthquake datasets: - earthquake_history.csv (spine + historical_seismicity, focal_depth_km, tectonic_stress_index) - fault_lines_earthquake.csv (distance_to_fault_km, seismic_hazard_index) - soil_liquefaction.csv (soil_liquefaction_index) - vs30_bedrock.csv (bedrock_amplification) - building_vulnerability.csv (building_vulnerability) - population_earthquake.csv (population_density_norm) """ print("[Earthquake] Using REAL data loader") def nearest_merge(base_df, aux_df, cols, base_lat="latitude", base_lon="longitude", aux_lat="latitude", aux_lon="longitude"): if aux_lat not in aux_df.columns or aux_lon not in aux_df.columns: raise ValueError( f"nearest_merge: aux_df missing lat/lon. Has: {list(aux_df.columns)}" ) tree = cKDTree(aux_df[[aux_lat, aux_lon]].values) _, idxs = tree.query(base_df[[base_lat, base_lon]].values) base_df = base_df.copy() for col in cols: if col not in aux_df.columns: raise ValueError( f"nearest_merge: '{col}' not in aux_df. " f"Has: {list(aux_df.columns)}" ) base_df[col] = aux_df[col].iloc[idxs].values return base_df # ── Load ────────────────────────────────────────────────────────────── print("[Earthquake] Loading CSVs...") history = pd.read_csv(os.path.join(DATA_DIR, "earthquake_history.csv")) faults = pd.read_csv(os.path.join(DATA_DIR, "fault_lines_earthquake.csv")) liquef = pd.read_csv(os.path.join(DATA_DIR, "soil_liquefaction.csv")) vs30 = pd.read_csv(os.path.join(DATA_DIR, "vs30_bedrock.csv")) bldg = pd.read_csv(os.path.join(DATA_DIR, "building_vulnerability.csv")) pop = pd.read_csv(os.path.join(DATA_DIR, "population_earthquake.csv")) for df in (history, faults, liquef, vs30, bldg, pop): df.columns = df.columns.str.lower().str.strip() print(f"[Earthquake] History: {len(history)} rows, cols: {list(history.columns)}") print(f"[Earthquake] Faults cols: {list(faults.columns)}") print(f"[Earthquake] Liquef cols: {list(liquef.columns)}") print(f"[Earthquake] VS30 cols: {list(vs30.columns)}") print(f"[Earthquake] Bldg cols: {list(bldg.columns)}") print(f"[Earthquake] Pop cols: {list(pop.columns)}") # ── Clean history spine ─────────────────────────────────────────────── history = history.dropna(subset=["latitude", "longitude"]) history["date"] = pd.to_datetime(history["date"], errors="coerce") history = history.dropna(subset=["date"]) print(f"[Earthquake] After date clean: {len(history)} rows") if len(history) == 0: raise ValueError( "earthquake_history has 0 rows after date parsing. " f"Sample raw dates: {pd.read_csv(os.path.join(DATA_DIR, 'earthquake_history.csv'))['date'].head().tolist()}" ) base = history.copy() # ── Fault lines → distance_to_fault_km ─────────────────────────────── # fault_lines_earthquake already has distance_to_fault_km as a column # but we still spatial-join to get the nearest fault's values print("[Earthquake] Merging fault lines...") base = nearest_merge(base, faults, ["distance_to_fault_km"]) # ── Soil liquefaction ───────────────────────────────────────────────── print("[Earthquake] Merging soil liquefaction...") base = nearest_merge(base, liquef, ["soil_liquefaction_index"]) # ── VS30 / bedrock amplification ────────────────────────────────────── print("[Earthquake] Merging VS30 bedrock...") base = nearest_merge(base, vs30, ["bedrock_amplification"]) # ── Building vulnerability ──────────────────────────────────────────── print("[Earthquake] Merging building vulnerability...") base = nearest_merge(base, bldg, ["building_vulnerability"]) # ── Population density ──────────────────────────────────────────────── print("[Earthquake] Merging population...") base = nearest_merge(base, pop, ["population_density_norm"]) # ── Validate all required columns present ───────────────────────────── required = [ "historical_seismicity", "distance_to_fault_km", "soil_liquefaction_index", "focal_depth_km", "tectonic_stress_index", "building_vulnerability", "population_density_norm", "bedrock_amplification", ] missing = [c for c in required if c not in base.columns] if missing: raise ValueError( f"Missing columns after all merges: {missing}\n" f"Available: {list(base.columns)}" ) base = base.dropna(subset=required) print(f"[Earthquake] Rows after dropna: {len(base)}") if len(base) < 50: raise ValueError( f"Only {len(base)} clean rows — check CSV paths and column names" ) # ── Risk label ──────────────────────────────────────────────────────── # Use magnitude if available, otherwise derive from features if "magnitude" in base.columns: base["magnitude"] = pd.to_numeric(base["magnitude"], errors="coerce").fillna(0) mag_norm = np.clip((base["magnitude"] - 2.0) / 7.0, 0, 1) # scale 2–9 else: mag_norm = pd.Series(np.zeros(len(base))) depth_norm = np.clip(base["focal_depth_km"] / 700.0, 0, 1) fault_norm = np.clip(base["distance_to_fault_km"] / 200.0, 0, 1) liquef_norm = np.clip(base["soil_liquefaction_index"], 0, 1) vuln_norm = np.clip(base["building_vulnerability"], 0, 1) pop_norm = np.clip(base["population_density_norm"], 0, 1) amp_norm = np.clip(base["bedrock_amplification"], 0, 1) stress_norm = np.clip(base["tectonic_stress_index"], 0, 1) seism_norm = np.clip(base["historical_seismicity"], 0, 1) base["risk_score"] = np.clip( 0.25 * mag_norm.values + 0.20 * (1 - depth_norm) + # shallow = more damage 0.15 * (1 - fault_norm) + # close to fault = more risk 0.15 * liquef_norm + 0.10 * vuln_norm + 0.05 * pop_norm + 0.05 * amp_norm + 0.05 * seism_norm + np.random.normal(0, 0.02, len(base)), 0.0, 1.0 ) print(f"[Earthquake] Risk score: mean={base['risk_score'].mean():.3f}, " f"std={base['risk_score'].std():.3f}, " f">0.5: {(base['risk_score'] > 0.5).sum()} rows") # ── Normalise features ──────────────────────────────────────────────── features = [ "historical_seismicity", "distance_to_fault_km", "soil_liquefaction_index", "focal_depth_km", "tectonic_stress_index", "building_vulnerability", "population_density_norm", "bedrock_amplification", ] # Add immediately before the assert print(f"[Earthquake] Columns available before assert: {list(base.columns)}") print(f"[Earthquake] Required: {required}") print(f"[Earthquake] Missing: {[c for c in required if c not in base.columns]}") assert features == list(EARTHQUAKE_FEATURES), ( f"Feature mismatch!\n train: {features}\n" f" predictor: {list(EARTHQUAKE_FEATURES)}" ) from src.disaster_predictors import FEATURE_RANGES X = np.zeros((len(base), len(features)), dtype=np.float32) for i, feat in enumerate(features): lo, hi = FEATURE_RANGES[feat] X[:, i] = np.clip( (base[feat].values - lo) / (hi - lo + 1e-8), 0, 1 ) y = base["risk_score"].values.astype(np.float32) return X, y DATA_GENERATORS = { "flood": (generate_flood_data, FLOOD_FEATURES), "cyclone": (generate_cyclone_data, CYCLONE_FEATURES), "landslide": (generate_landslide_data, LANDSLIDE_FEATURES), "earthquake": (generate_earthquake_data, EARTHQUAKE_FEATURES), } # ============================================================================ # TRAINING PIPELINE # ============================================================================ def evaluate_model(model: FuzzyNeuralNetwork, X: torch.Tensor, y: torch.Tensor) -> dict: model.eval() with torch.no_grad(): preds = model(X).numpy().flatten() y_np = y.numpy() threshold = float(np.median(y_np)) try: auc = float(roc_auc_score((y_np > threshold).astype(int), preds)) except ValueError as e: print(f" [Warning] AUC undefined: {e}") auc = float("nan") # always float, never string mae = mean_absolute_error(y_np, preds) return { "MAE": round(float(mae), 4), "AUC-ROC": round(auc, 4) if not np.isnan(auc) else float("nan"), "Mean Prediction": round(float(preds.mean()), 4), "Mean Label": round(float(y_np.mean()), 4), "Std Prediction": round(float(preds.std()), 4), } def train_disaster_model(disaster_type: str, epochs: int = 1000, n_samples: int = None): print(f"\n{'='*60}") print(f" Training FNN for: {disaster_type.upper()}") print(f"{'='*60}") generator_fn, feature_names = DATA_GENERATORS[disaster_type] n = n_samples or { "flood": 5000, "cyclone": 3000, "landslide": 4000, "earthquake": 3000 }[disaster_type] REAL_DATA_GENERATORS = {"flood", "cyclone", "landslide", "earthquake"} if disaster_type in REAL_DATA_GENERATORS: print(f"Loading real data for {disaster_type}...") else: print(f"Generating {n} synthetic samples...") X, y = generator_fn(n) print(f" Data loaded: X={X.shape}, y={y.shape}, " f"y_mean={y.mean():.3f}, y_std={y.std():.3f}") # Train/val/test split X_trainval, X_test, y_trainval, y_test = train_test_split( X, y, test_size=0.15, random_state=SEED ) X_train, X_val, y_train, y_val = train_test_split( X_trainval, y_trainval, test_size=0.15, random_state=SEED ) print(f" Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}") # Tensors X_train_t = torch.tensor(X_train) y_train_t = torch.tensor(y_train) X_val_t = torch.tensor(X_val) y_val_t = torch.tensor(y_val) X_test_t = torch.tensor(X_test) y_test_t = torch.tensor(y_test) # Model n_features = len(feature_names) model = FuzzyNeuralNetwork( n_features=n_features, n_terms=3, hidden_dims=[64, 32], dropout=0.2 ) print(f" Model: FNN with {n_features} inputs, 3 fuzzy terms, 64→32 deep head") total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f" Trainable parameters: {total_params:,}") # Train trainer = FNNTrainer(model, lr=1e-3, weight_decay=1e-4) trainer.fit( X_train_t, y_train_t, X_val_t, y_val_t, epochs=epochs, batch_size=64, patience=50 ) # Evaluate print("\n Test set evaluation:") metrics = evaluate_model(model, X_test_t, y_test_t) for k, v in metrics.items(): print(f" {k}: {v}") # Save os.makedirs(MODEL_DIR, exist_ok=True) model_path = os.path.join(MODEL_DIR, f"fnn_{disaster_type}_model.pt") save_model(model, model_path, feature_names) feat_path = os.path.join(MODEL_DIR, "feature_names", f"{disaster_type}_features.txt") os.makedirs(os.path.dirname(feat_path), exist_ok=True) with open(feat_path, "w") as f: f.write("\n".join(feature_names)) print(f"\n Model saved to: {model_path}") return metrics # ← always returns, never None def train_all(epochs: int = 200): results = {} for disaster_type in DATA_GENERATORS: try: metrics = train_disaster_model(disaster_type, epochs=epochs) if metrics is None: raise RuntimeError("train_disaster_model returned None") results[disaster_type] = metrics except Exception as e: print(f"\n [ERROR] {disaster_type} training failed: {e}") import traceback traceback.print_exc() results[disaster_type] = { "MAE": float("nan"), "AUC-ROC": float("nan"), "Mean Prediction": float("nan"), "Std Prediction": float("nan"), } print("\n" + "="*60) print(" TRAINING SUMMARY") print("="*60) for dt, metrics in results.items(): auc = metrics["AUC-ROC"] mae = metrics["MAE"] auc_str = f"{auc:.4f}" if isinstance(auc, float) and not np.isnan(auc) else "nan" mae_str = f"{mae:.4f}" if isinstance(mae, float) and not np.isnan(mae) else "nan" print(f" {dt.upper():12s} | MAE: {mae_str} | AUC: {auc_str}") print("="*60) # ============================================================================ # ENTRY POINT # ============================================================================ if __name__ == "__main__": parser = argparse.ArgumentParser(description="Train FNN disaster models") parser.add_argument( "--disaster", choices=list(DATA_GENERATORS.keys()) + ["all"], default="all", help="Which disaster model to train" ) parser.add_argument("--epochs", type=int, default=200) parser.add_argument("--samples", type=int, default=None) args = parser.parse_args() if args.disaster == "all": train_all(epochs=args.epochs) else: train_disaster_model(args.disaster, epochs=args.epochs, n_samples=args.samples)