clarindasusan committed on
Commit
349d88b
·
verified ·
1 Parent(s): d5dcb2d

Update src/train_model.py

Browse files
Files changed (1) hide show
  1. src/train_model.py +170 -83
src/train_model.py CHANGED
@@ -39,7 +39,7 @@ from src.disaster_predictors import (
39
  FLOOD_FEATURES, CYCLONE_FEATURES, LANDSLIDE_FEATURES, EARTHQUAKE_FEATURES
40
  )
41
 
42
- MODEL_DIR = "models"
43
  SEED = 42
44
  np.random.seed(SEED)
45
  torch.manual_seed(SEED)
@@ -182,6 +182,7 @@ def generate_cyclone_data(n: int = 3000):
182
  def nearest_merge(base_df, aux_df, cols):
183
  tree = cKDTree(aux_df[["latitude", "longitude"]].values)
184
  _, idxs = tree.query(base_df[["latitude", "longitude"]].values)
 
185
  for col in cols:
186
  base_df[col] = aux_df[col].iloc[idxs].values
187
  return base_df
@@ -441,37 +442,168 @@ def generate_landslide_data(n: int = 4000):
441
  return X, y
442
 
443
  def generate_earthquake_data(n: int = 3000):
444
- rng = np.random.default_rng(SEED + 3)
445
-
446
- hist_seism_norm = rng.beta(2, 4, n)
447
- fault_norm = rng.beta(2, 2, n) # Higher = farther from fault
448
- liquef_norm = rng.beta(2, 4, n)
449
- depth_norm = rng.beta(3, 2, n) # Higher = deeper = less damage
450
- stress_norm = rng.beta(2, 3, n)
451
- vuln_norm = rng.beta(2, 3, n)
452
- pop_norm = rng.beta(2, 2, n)
453
- amp_norm = rng.beta(2, 3, n)
454
-
455
- X = np.column_stack([
456
- hist_seism_norm, fault_norm, liquef_norm, depth_norm,
457
- stress_norm, vuln_norm, pop_norm, amp_norm
458
- ])
459
-
460
- risk = (
461
- 0.25 * hist_seism_norm +
462
- 0.20 * (1 - fault_norm) + # Close to fault = more risk
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  0.15 * liquef_norm +
464
- 0.10 * (1 - depth_norm) + # Shallow = more damage
465
- 0.10 * stress_norm +
466
  0.10 * vuln_norm +
467
  0.05 * pop_norm +
468
- 0.05 * amp_norm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
469
  )
470
 
471
- risk += rng.normal(0, 0.05, n)
472
- y = np.clip(risk, 0.0, 1.0).astype(np.float32)
 
 
 
 
 
473
 
474
- return X.astype(np.float32), y
 
475
 
476
 
477
  DATA_GENERATORS = {
@@ -482,6 +614,7 @@ DATA_GENERATORS = {
482
  }
483
 
484
 
 
485
  # ============================================================================
486
  # TRAINING PIPELINE
487
  # ============================================================================
@@ -514,65 +647,19 @@ def train_disaster_model(disaster_type: str, epochs: int = 200, n_samples: int =
514
  print(f"{'='*60}")
515
 
516
  generator_fn, feature_names = DATA_GENERATORS[disaster_type]
517
- n = n_samples or {"flood": 5000, "cyclone": 3000, "landslide": 4000, "earthquake": 3000}[disaster_type]
 
 
 
 
 
 
 
 
 
518
 
519
- print(f"Loading data (n_samples hint: {n})...")
520
  X, y = generator_fn(n)
521
-
522
- # Train/val/test split
523
- X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size=0.15, random_state=SEED)
524
- X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=0.15, random_state=SEED)
525
-
526
- print(f" Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")
527
-
528
- # Tensors
529
- X_train_t = torch.tensor(X_train)
530
- y_train_t = torch.tensor(y_train)
531
- X_val_t = torch.tensor(X_val)
532
- y_val_t = torch.tensor(y_val)
533
- X_test_t = torch.tensor(X_test)
534
- y_test_t = torch.tensor(y_test)
535
-
536
- # Model
537
- n_features = len(feature_names)
538
- model = FuzzyNeuralNetwork(
539
- n_features=n_features,
540
- n_terms=3,
541
- hidden_dims=[64, 32],
542
- dropout=0.2
543
- )
544
-
545
- print(f" Model: FNN with {n_features} inputs, 3 fuzzy terms, 64→32 deep head")
546
- total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
547
- print(f" Trainable parameters: {total_params:,}")
548
-
549
- # Train
550
- trainer = FNNTrainer(model, lr=1e-3, weight_decay=1e-4)
551
- history = trainer.fit(
552
- X_train_t, y_train_t,
553
- X_val_t, y_val_t,
554
- epochs=epochs, batch_size=64, patience=25
555
- )
556
-
557
- # Evaluate
558
- print("\n Test set evaluation:")
559
- metrics = evaluate_model(model, X_test_t, y_test_t)
560
- for k, v in metrics.items():
561
- print(f" {k}: {v}")
562
-
563
- # Save
564
- os.makedirs(MODEL_DIR, exist_ok=True)
565
- model_path = os.path.join(MODEL_DIR, f"fnn_{disaster_type}_model.pt")
566
- save_model(model, model_path, feature_names)
567
-
568
- # Save feature names as text too
569
- feat_path = os.path.join(MODEL_DIR, "feature_names", f"{disaster_type}_features.txt")
570
- os.makedirs(os.path.dirname(feat_path), exist_ok=True)
571
- with open(feat_path, "w") as f:
572
- f.write("\n".join(feature_names))
573
-
574
- print(f"\n Model saved to: {model_path}")
575
- return metrics
576
 
577
 
578
  def train_all(epochs: int = 200):
 
39
  FLOOD_FEATURES, CYCLONE_FEATURES, LANDSLIDE_FEATURES, EARTHQUAKE_FEATURES
40
  )
41
 
42
+ MODEL_DIR = os.path.join(BASE_DIR, "models")
43
  SEED = 42
44
  np.random.seed(SEED)
45
  torch.manual_seed(SEED)
 
182
  def nearest_merge(base_df, aux_df, cols):
183
  tree = cKDTree(aux_df[["latitude", "longitude"]].values)
184
  _, idxs = tree.query(base_df[["latitude", "longitude"]].values)
185
+ base_df = base_df.copy()
186
  for col in cols:
187
  base_df[col] = aux_df[col].iloc[idxs].values
188
  return base_df
 
442
  return X, y
443
 
444
  def generate_earthquake_data(n: int = 3000):
445
+ """
446
+ Loads and joins earthquake datasets:
447
+ - earthquake_history.csv (spine + historical_seismicity, focal_depth_km, tectonic_stress_index)
448
+ - fault_lines_earthquake.csv (distance_to_fault_km, seismic_hazard_index)
449
+ - soil_liquefaction.csv (soil_liquefaction_index)
450
+ - vs30_bedrock.csv (bedrock_amplification)
451
+ - building_vulnerability.csv (building_vulnerability)
452
+ - population_earthquake.csv (population_density_norm)
453
+ """
454
+ print("[Earthquake] Using REAL data loader")
455
+
456
+ def nearest_merge(base_df, aux_df, cols,
457
+ base_lat="latitude", base_lon="longitude",
458
+ aux_lat="latitude", aux_lon="longitude"):
459
+ if aux_lat not in aux_df.columns or aux_lon not in aux_df.columns:
460
+ raise ValueError(
461
+ f"nearest_merge: aux_df missing lat/lon. Has: {list(aux_df.columns)}"
462
+ )
463
+ tree = cKDTree(aux_df[[aux_lat, aux_lon]].values)
464
+ _, idxs = tree.query(base_df[[base_lat, base_lon]].values)
465
+ base_df = base_df.copy()
466
+ for col in cols:
467
+ if col not in aux_df.columns:
468
+ raise ValueError(
469
+ f"nearest_merge: '{col}' not in aux_df. "
470
+ f"Has: {list(aux_df.columns)}"
471
+ )
472
+ base_df[col] = aux_df[col].iloc[idxs].values
473
+ return base_df
474
+
475
+ # ── Load ──────────────────────────────────────────────────────────────
476
+ print("[Earthquake] Loading CSVs...")
477
+ history = pd.read_csv(os.path.join(DATA_DIR, "earthquake_history.csv"))
478
+ faults = pd.read_csv(os.path.join(DATA_DIR, "fault_lines_earthquake.csv"))
479
+ liquef = pd.read_csv(os.path.join(DATA_DIR, "soil_liquefaction.csv"))
480
+ vs30 = pd.read_csv(os.path.join(DATA_DIR, "vs30_bedrock.csv"))
481
+ bldg = pd.read_csv(os.path.join(DATA_DIR, "building_vulnerability.csv"))
482
+ pop = pd.read_csv(os.path.join(DATA_DIR, "population_earthquake.csv"))
483
+
484
+ for df in (history, faults, liquef, vs30, bldg, pop):
485
+ df.columns = df.columns.str.lower().str.strip()
486
+
487
+ print(f"[Earthquake] History: {len(history)} rows, cols: {list(history.columns)}")
488
+ print(f"[Earthquake] Faults cols: {list(faults.columns)}")
489
+ print(f"[Earthquake] Liquef cols: {list(liquef.columns)}")
490
+ print(f"[Earthquake] VS30 cols: {list(vs30.columns)}")
491
+ print(f"[Earthquake] Bldg cols: {list(bldg.columns)}")
492
+ print(f"[Earthquake] Pop cols: {list(pop.columns)}")
493
+
494
+ # ── Clean history spine ───────────────────────────────────────────────
495
+ history = history.dropna(subset=["latitude", "longitude"])
496
+ history["date"] = pd.to_datetime(history["date"], errors="coerce")
497
+ history = history.dropna(subset=["date"])
498
+ print(f"[Earthquake] After date clean: {len(history)} rows")
499
+
500
+ if len(history) == 0:
501
+ raise ValueError(
502
+ "earthquake_history has 0 rows after date parsing. "
503
+ f"Sample raw dates: {pd.read_csv(os.path.join(DATA_DIR, 'earthquake_history.csv'))['date'].head().tolist()}"
504
+ )
505
+
506
+ base = history.copy()
507
+
508
+ # ── Fault lines β†’ distance_to_fault_km ───────────────────────────────
509
+ # fault_lines_earthquake already has distance_to_fault_km as a column
510
+ # but we still spatial-join to get the nearest fault's values
511
+ print("[Earthquake] Merging fault lines...")
512
+ base = nearest_merge(base, faults, ["distance_to_fault_km"])
513
+
514
+ # ── Soil liquefaction ─────────────────────────────────────────────────
515
+ print("[Earthquake] Merging soil liquefaction...")
516
+ base = nearest_merge(base, liquef, ["soil_liquefaction_index"])
517
+
518
+ # ── VS30 / bedrock amplification ──────────────────────────────────────
519
+ print("[Earthquake] Merging VS30 bedrock...")
520
+ base = nearest_merge(base, vs30, ["bedrock_amplification"])
521
+
522
+ # ── Building vulnerability ────────────────────────────────────────────
523
+ print("[Earthquake] Merging building vulnerability...")
524
+ base = nearest_merge(base, bldg, ["building_vulnerability"])
525
+
526
+ # ── Population density ────────────────────────────────────────────────
527
+ print("[Earthquake] Merging population...")
528
+ base = nearest_merge(base, pop, ["population_density_norm"])
529
+
530
+ # ── Validate all required columns present ─────────────────────────────
531
+ required = [
532
+ "historical_seismicity", "distance_to_fault_km", "soil_liquefaction_index",
533
+ "focal_depth_km", "tectonic_stress_index", "building_vulnerability",
534
+ "population_density_norm", "bedrock_amplification",
535
+ ]
536
+ missing = [c for c in required if c not in base.columns]
537
+ if missing:
538
+ raise ValueError(
539
+ f"Missing columns after all merges: {missing}\n"
540
+ f"Available: {list(base.columns)}"
541
+ )
542
+
543
+ base = base.dropna(subset=required)
544
+ print(f"[Earthquake] Rows after dropna: {len(base)}")
545
+
546
+ if len(base) < 50:
547
+ raise ValueError(
548
+ f"Only {len(base)} clean rows — check CSV paths and column names"
549
+ )
550
+
551
+ # ── Risk label ────────────────────────────────────────────────────────
552
+ # Use magnitude if available, otherwise derive from features
553
+ if "magnitude" in base.columns:
554
+ base["magnitude"] = pd.to_numeric(base["magnitude"], errors="coerce").fillna(0)
555
+ mag_norm = np.clip((base["magnitude"] - 2.0) / 7.0, 0, 1) # scale 2–9
556
+ else:
557
+ mag_norm = pd.Series(np.zeros(len(base)))
558
+
559
+ depth_norm = np.clip(base["focal_depth_km"] / 700.0, 0, 1)
560
+ fault_norm = np.clip(base["distance_to_fault_km"] / 200.0, 0, 1)
561
+ liquef_norm = np.clip(base["soil_liquefaction_index"], 0, 1)
562
+ vuln_norm = np.clip(base["building_vulnerability"], 0, 1)
563
+ pop_norm = np.clip(base["population_density_norm"], 0, 1)
564
+ amp_norm = np.clip(base["bedrock_amplification"], 0, 1)
565
+ stress_norm = np.clip(base["tectonic_stress_index"], 0, 1)
566
+ seism_norm = np.clip(base["historical_seismicity"], 0, 1)
567
+
568
+ base["risk_score"] = np.clip(
569
+ 0.25 * mag_norm.values +
570
+ 0.20 * (1 - depth_norm) + # shallow = more damage
571
+ 0.15 * (1 - fault_norm) + # close to fault = more risk
572
  0.15 * liquef_norm +
 
 
573
  0.10 * vuln_norm +
574
  0.05 * pop_norm +
575
+ 0.05 * amp_norm +
576
+ 0.05 * seism_norm +
577
+ np.random.normal(0, 0.02, len(base)),
578
+ 0.0, 1.0
579
+ )
580
+
581
+ print(f"[Earthquake] Risk score: mean={base['risk_score'].mean():.3f}, "
582
+ f"std={base['risk_score'].std():.3f}, "
583
+ f">0.5: {(base['risk_score'] > 0.5).sum()} rows")
584
+
585
+ # ── Normalise features ────────────────────────────────────────────────
586
+ features = [
587
+ "historical_seismicity", "distance_to_fault_km", "soil_liquefaction_index",
588
+ "focal_depth_km", "tectonic_stress_index", "building_vulnerability",
589
+ "population_density_norm", "bedrock_amplification",
590
+ ]
591
+
592
+ assert features == list(EARTHQUAKE_FEATURES), (
593
+ f"Feature mismatch!\n train: {features}\n"
594
+ f" predictor: {list(EARTHQUAKE_FEATURES)}"
595
  )
596
 
597
+ from src.disaster_predictors import FEATURE_RANGES
598
+ X = np.zeros((len(base), len(features)), dtype=np.float32)
599
+ for i, feat in enumerate(features):
600
+ lo, hi = FEATURE_RANGES[feat]
601
+ X[:, i] = np.clip(
602
+ (base[feat].values - lo) / (hi - lo + 1e-8), 0, 1
603
+ )
604
 
605
+ y = base["risk_score"].values.astype(np.float32)
606
+ return X, y
607
 
608
 
609
  DATA_GENERATORS = {
 
614
  }
615
 
616
 
617
+
618
  # ============================================================================
619
  # TRAINING PIPELINE
620
  # ============================================================================
 
647
  print(f"{'='*60}")
648
 
649
  generator_fn, feature_names = DATA_GENERATORS[disaster_type]
650
+ n = n_samples or {
651
+ "flood": 5000, "cyclone": 3000,
652
+ "landslide": 4000, "earthquake": 3000
653
+ }[disaster_type]
654
+
655
+ REAL_DATA_GENERATORS = {"flood", "cyclone", "landslide", "earthquake"}
656
+ if disaster_type in REAL_DATA_GENERATORS:
657
+ print(f"Loading real data for {disaster_type}...")
658
+ else:
659
+ print(f"Generating {n} synthetic samples...")
660
 
 
661
  X, y = generator_fn(n)
662
+ # ... rest unchanged
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
663
 
664
 
665
  def train_all(epochs: int = 200):