Spaces:
Sleeping
Sleeping
Update src/train_model.py
Browse files- src/train_model.py +170 -83
src/train_model.py
CHANGED
|
@@ -39,7 +39,7 @@ from src.disaster_predictors import (
|
|
| 39 |
FLOOD_FEATURES, CYCLONE_FEATURES, LANDSLIDE_FEATURES, EARTHQUAKE_FEATURES
|
| 40 |
)
|
| 41 |
|
| 42 |
-
MODEL_DIR = "models"
|
| 43 |
SEED = 42
|
| 44 |
np.random.seed(SEED)
|
| 45 |
torch.manual_seed(SEED)
|
|
@@ -182,6 +182,7 @@ def generate_cyclone_data(n: int = 3000):
|
|
| 182 |
def nearest_merge(base_df, aux_df, cols):
|
| 183 |
tree = cKDTree(aux_df[["latitude", "longitude"]].values)
|
| 184 |
_, idxs = tree.query(base_df[["latitude", "longitude"]].values)
|
|
|
|
| 185 |
for col in cols:
|
| 186 |
base_df[col] = aux_df[col].iloc[idxs].values
|
| 187 |
return base_df
|
|
@@ -441,37 +442,168 @@ def generate_landslide_data(n: int = 4000):
|
|
| 441 |
return X, y
|
| 442 |
|
| 443 |
def generate_earthquake_data(n: int = 3000):
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
0.15 * liquef_norm +
|
| 464 |
-
0.10 * (1 - depth_norm) + # Shallow = more damage
|
| 465 |
-
0.10 * stress_norm +
|
| 466 |
0.10 * vuln_norm +
|
| 467 |
0.05 * pop_norm +
|
| 468 |
-
0.05 * amp_norm
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
)
|
| 470 |
|
| 471 |
-
|
| 472 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 473 |
|
| 474 |
-
|
|
|
|
| 475 |
|
| 476 |
|
| 477 |
DATA_GENERATORS = {
|
|
@@ -482,6 +614,7 @@ DATA_GENERATORS = {
|
|
| 482 |
}
|
| 483 |
|
| 484 |
|
|
|
|
| 485 |
# ============================================================================
|
| 486 |
# TRAINING PIPELINE
|
| 487 |
# ============================================================================
|
|
@@ -514,65 +647,19 @@ def train_disaster_model(disaster_type: str, epochs: int = 200, n_samples: int =
|
|
| 514 |
print(f"{'='*60}")
|
| 515 |
|
| 516 |
generator_fn, feature_names = DATA_GENERATORS[disaster_type]
|
| 517 |
-
n = n_samples or {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
|
| 519 |
-
print(f"Loading data (n_samples hint: {n})...")
|
| 520 |
X, y = generator_fn(n)
|
| 521 |
-
|
| 522 |
-
# Train/val/test split
|
| 523 |
-
X_trainval, X_test, y_trainval, y_test = train_test_split(X, y, test_size=0.15, random_state=SEED)
|
| 524 |
-
X_train, X_val, y_train, y_val = train_test_split(X_trainval, y_trainval, test_size=0.15, random_state=SEED)
|
| 525 |
-
|
| 526 |
-
print(f" Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")
|
| 527 |
-
|
| 528 |
-
# Tensors
|
| 529 |
-
X_train_t = torch.tensor(X_train)
|
| 530 |
-
y_train_t = torch.tensor(y_train)
|
| 531 |
-
X_val_t = torch.tensor(X_val)
|
| 532 |
-
y_val_t = torch.tensor(y_val)
|
| 533 |
-
X_test_t = torch.tensor(X_test)
|
| 534 |
-
y_test_t = torch.tensor(y_test)
|
| 535 |
-
|
| 536 |
-
# Model
|
| 537 |
-
n_features = len(feature_names)
|
| 538 |
-
model = FuzzyNeuralNetwork(
|
| 539 |
-
n_features=n_features,
|
| 540 |
-
n_terms=3,
|
| 541 |
-
hidden_dims=[64, 32],
|
| 542 |
-
dropout=0.2
|
| 543 |
-
)
|
| 544 |
-
|
| 545 |
-
print(f" Model: FNN with {n_features} inputs, 3 fuzzy terms, 64→32 deep head")
|
| 546 |
-
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 547 |
-
print(f" Trainable parameters: {total_params:,}")
|
| 548 |
-
|
| 549 |
-
# Train
|
| 550 |
-
trainer = FNNTrainer(model, lr=1e-3, weight_decay=1e-4)
|
| 551 |
-
history = trainer.fit(
|
| 552 |
-
X_train_t, y_train_t,
|
| 553 |
-
X_val_t, y_val_t,
|
| 554 |
-
epochs=epochs, batch_size=64, patience=25
|
| 555 |
-
)
|
| 556 |
-
|
| 557 |
-
# Evaluate
|
| 558 |
-
print("\n Test set evaluation:")
|
| 559 |
-
metrics = evaluate_model(model, X_test_t, y_test_t)
|
| 560 |
-
for k, v in metrics.items():
|
| 561 |
-
print(f" {k}: {v}")
|
| 562 |
-
|
| 563 |
-
# Save
|
| 564 |
-
os.makedirs(MODEL_DIR, exist_ok=True)
|
| 565 |
-
model_path = os.path.join(MODEL_DIR, f"fnn_{disaster_type}_model.pt")
|
| 566 |
-
save_model(model, model_path, feature_names)
|
| 567 |
-
|
| 568 |
-
# Save feature names as text too
|
| 569 |
-
feat_path = os.path.join(MODEL_DIR, "feature_names", f"{disaster_type}_features.txt")
|
| 570 |
-
os.makedirs(os.path.dirname(feat_path), exist_ok=True)
|
| 571 |
-
with open(feat_path, "w") as f:
|
| 572 |
-
f.write("\n".join(feature_names))
|
| 573 |
-
|
| 574 |
-
print(f"\n Model saved to: {model_path}")
|
| 575 |
-
return metrics
|
| 576 |
|
| 577 |
|
| 578 |
def train_all(epochs: int = 200):
|
|
|
|
| 39 |
FLOOD_FEATURES, CYCLONE_FEATURES, LANDSLIDE_FEATURES, EARTHQUAKE_FEATURES
|
| 40 |
)
|
| 41 |
|
| 42 |
+
MODEL_DIR = os.path.join(BASE_DIR, "models")
|
| 43 |
SEED = 42
|
| 44 |
np.random.seed(SEED)
|
| 45 |
torch.manual_seed(SEED)
|
|
|
|
| 182 |
def nearest_merge(base_df, aux_df, cols):
|
| 183 |
tree = cKDTree(aux_df[["latitude", "longitude"]].values)
|
| 184 |
_, idxs = tree.query(base_df[["latitude", "longitude"]].values)
|
| 185 |
+
base_df = base_df.copy()
|
| 186 |
for col in cols:
|
| 187 |
base_df[col] = aux_df[col].iloc[idxs].values
|
| 188 |
return base_df
|
|
|
|
| 442 |
return X, y
|
| 443 |
|
| 444 |
def generate_earthquake_data(n: int = 3000):
|
| 445 |
+
"""
|
| 446 |
+
Loads and joins earthquake datasets:
|
| 447 |
+
- earthquake_history.csv (spine + historical_seismicity, focal_depth_km, tectonic_stress_index)
|
| 448 |
+
- fault_lines_earthquake.csv (distance_to_fault_km, seismic_hazard_index)
|
| 449 |
+
- soil_liquefaction.csv (soil_liquefaction_index)
|
| 450 |
+
- vs30_bedrock.csv (bedrock_amplification)
|
| 451 |
+
- building_vulnerability.csv (building_vulnerability)
|
| 452 |
+
- population_earthquake.csv (population_density_norm)
|
| 453 |
+
"""
|
| 454 |
+
print("[Earthquake] Using REAL data loader")
|
| 455 |
+
|
| 456 |
+
def nearest_merge(base_df, aux_df, cols,
|
| 457 |
+
base_lat="latitude", base_lon="longitude",
|
| 458 |
+
aux_lat="latitude", aux_lon="longitude"):
|
| 459 |
+
if aux_lat not in aux_df.columns or aux_lon not in aux_df.columns:
|
| 460 |
+
raise ValueError(
|
| 461 |
+
f"nearest_merge: aux_df missing lat/lon. Has: {list(aux_df.columns)}"
|
| 462 |
+
)
|
| 463 |
+
tree = cKDTree(aux_df[[aux_lat, aux_lon]].values)
|
| 464 |
+
_, idxs = tree.query(base_df[[base_lat, base_lon]].values)
|
| 465 |
+
base_df = base_df.copy()
|
| 466 |
+
for col in cols:
|
| 467 |
+
if col not in aux_df.columns:
|
| 468 |
+
raise ValueError(
|
| 469 |
+
f"nearest_merge: '{col}' not in aux_df. "
|
| 470 |
+
f"Has: {list(aux_df.columns)}"
|
| 471 |
+
)
|
| 472 |
+
base_df[col] = aux_df[col].iloc[idxs].values
|
| 473 |
+
return base_df
|
| 474 |
+
|
| 475 |
+
# ── Load ──────────────────────────────────────────────────────────────
|
| 476 |
+
print("[Earthquake] Loading CSVs...")
|
| 477 |
+
history = pd.read_csv(os.path.join(DATA_DIR, "earthquake_history.csv"))
|
| 478 |
+
faults = pd.read_csv(os.path.join(DATA_DIR, "fault_lines_earthquake.csv"))
|
| 479 |
+
liquef = pd.read_csv(os.path.join(DATA_DIR, "soil_liquefaction.csv"))
|
| 480 |
+
vs30 = pd.read_csv(os.path.join(DATA_DIR, "vs30_bedrock.csv"))
|
| 481 |
+
bldg = pd.read_csv(os.path.join(DATA_DIR, "building_vulnerability.csv"))
|
| 482 |
+
pop = pd.read_csv(os.path.join(DATA_DIR, "population_earthquake.csv"))
|
| 483 |
+
|
| 484 |
+
for df in (history, faults, liquef, vs30, bldg, pop):
|
| 485 |
+
df.columns = df.columns.str.lower().str.strip()
|
| 486 |
+
|
| 487 |
+
print(f"[Earthquake] History: {len(history)} rows, cols: {list(history.columns)}")
|
| 488 |
+
print(f"[Earthquake] Faults cols: {list(faults.columns)}")
|
| 489 |
+
print(f"[Earthquake] Liquef cols: {list(liquef.columns)}")
|
| 490 |
+
print(f"[Earthquake] VS30 cols: {list(vs30.columns)}")
|
| 491 |
+
print(f"[Earthquake] Bldg cols: {list(bldg.columns)}")
|
| 492 |
+
print(f"[Earthquake] Pop cols: {list(pop.columns)}")
|
| 493 |
+
|
| 494 |
+
# ── Clean history spine ───────────────────────────────────────────────
|
| 495 |
+
history = history.dropna(subset=["latitude", "longitude"])
|
| 496 |
+
history["date"] = pd.to_datetime(history["date"], errors="coerce")
|
| 497 |
+
history = history.dropna(subset=["date"])
|
| 498 |
+
print(f"[Earthquake] After date clean: {len(history)} rows")
|
| 499 |
+
|
| 500 |
+
if len(history) == 0:
|
| 501 |
+
raise ValueError(
|
| 502 |
+
"earthquake_history has 0 rows after date parsing. "
|
| 503 |
+
f"Sample raw dates: {pd.read_csv(os.path.join(DATA_DIR, 'earthquake_history.csv'))['date'].head().tolist()}"
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
base = history.copy()
|
| 507 |
+
|
| 508 |
+
# ── Fault lines → distance_to_fault_km ────────────────────────────────
|
| 509 |
+
# fault_lines_earthquake already has distance_to_fault_km as a column
|
| 510 |
+
# but we still spatial-join to get the nearest fault's values
|
| 511 |
+
print("[Earthquake] Merging fault lines...")
|
| 512 |
+
base = nearest_merge(base, faults, ["distance_to_fault_km"])
|
| 513 |
+
|
| 514 |
+
# ── Soil liquefaction ─────────────────────────────────────────────────
|
| 515 |
+
print("[Earthquake] Merging soil liquefaction...")
|
| 516 |
+
base = nearest_merge(base, liquef, ["soil_liquefaction_index"])
|
| 517 |
+
|
| 518 |
+
# ── VS30 / bedrock amplification ──────────────────────────────────────
|
| 519 |
+
print("[Earthquake] Merging VS30 bedrock...")
|
| 520 |
+
base = nearest_merge(base, vs30, ["bedrock_amplification"])
|
| 521 |
+
|
| 522 |
+
# ── Building vulnerability ────────────────────────────────────────────
|
| 523 |
+
print("[Earthquake] Merging building vulnerability...")
|
| 524 |
+
base = nearest_merge(base, bldg, ["building_vulnerability"])
|
| 525 |
+
|
| 526 |
+
# ── Population density ────────────────────────────────────────────────
|
| 527 |
+
print("[Earthquake] Merging population...")
|
| 528 |
+
base = nearest_merge(base, pop, ["population_density_norm"])
|
| 529 |
+
|
| 530 |
+
# ── Validate all required columns present ─────────────────────────────
|
| 531 |
+
required = [
|
| 532 |
+
"historical_seismicity", "distance_to_fault_km", "soil_liquefaction_index",
|
| 533 |
+
"focal_depth_km", "tectonic_stress_index", "building_vulnerability",
|
| 534 |
+
"population_density_norm", "bedrock_amplification",
|
| 535 |
+
]
|
| 536 |
+
missing = [c for c in required if c not in base.columns]
|
| 537 |
+
if missing:
|
| 538 |
+
raise ValueError(
|
| 539 |
+
f"Missing columns after all merges: {missing}\n"
|
| 540 |
+
f"Available: {list(base.columns)}"
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
base = base.dropna(subset=required)
|
| 544 |
+
print(f"[Earthquake] Rows after dropna: {len(base)}")
|
| 545 |
+
|
| 546 |
+
if len(base) < 50:
|
| 547 |
+
raise ValueError(
|
| 548 |
+
f"Only {len(base)} clean rows — check CSV paths and column names"
|
| 549 |
+
)
|
| 550 |
+
|
| 551 |
+
# ── Risk label ────────────────────────────────────────────────────────
|
| 552 |
+
# Use magnitude if available, otherwise derive from features
|
| 553 |
+
if "magnitude" in base.columns:
|
| 554 |
+
base["magnitude"] = pd.to_numeric(base["magnitude"], errors="coerce").fillna(0)
|
| 555 |
+
mag_norm = np.clip((base["magnitude"] - 2.0) / 7.0, 0, 1) # scale 2–9
|
| 556 |
+
else:
|
| 557 |
+
mag_norm = pd.Series(np.zeros(len(base)))
|
| 558 |
+
|
| 559 |
+
depth_norm = np.clip(base["focal_depth_km"] / 700.0, 0, 1)
|
| 560 |
+
fault_norm = np.clip(base["distance_to_fault_km"] / 200.0, 0, 1)
|
| 561 |
+
liquef_norm = np.clip(base["soil_liquefaction_index"], 0, 1)
|
| 562 |
+
vuln_norm = np.clip(base["building_vulnerability"], 0, 1)
|
| 563 |
+
pop_norm = np.clip(base["population_density_norm"], 0, 1)
|
| 564 |
+
amp_norm = np.clip(base["bedrock_amplification"], 0, 1)
|
| 565 |
+
stress_norm = np.clip(base["tectonic_stress_index"], 0, 1)
|
| 566 |
+
seism_norm = np.clip(base["historical_seismicity"], 0, 1)
|
| 567 |
+
|
| 568 |
+
base["risk_score"] = np.clip(
|
| 569 |
+
0.25 * mag_norm.values +
|
| 570 |
+
0.20 * (1 - depth_norm) + # shallow = more damage
|
| 571 |
+
0.15 * (1 - fault_norm) + # close to fault = more risk
|
| 572 |
0.15 * liquef_norm +
|
|
|
|
|
|
|
| 573 |
0.10 * vuln_norm +
|
| 574 |
0.05 * pop_norm +
|
| 575 |
+
0.05 * amp_norm +
|
| 576 |
+
0.05 * seism_norm +
|
| 577 |
+
np.random.normal(0, 0.02, len(base)),
|
| 578 |
+
0.0, 1.0
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
print(f"[Earthquake] Risk score: mean={base['risk_score'].mean():.3f}, "
|
| 582 |
+
f"std={base['risk_score'].std():.3f}, "
|
| 583 |
+
f">0.5: {(base['risk_score'] > 0.5).sum()} rows")
|
| 584 |
+
|
| 585 |
+
# ── Normalise features ────────────────────────────────────────────────
|
| 586 |
+
features = [
|
| 587 |
+
"historical_seismicity", "distance_to_fault_km", "soil_liquefaction_index",
|
| 588 |
+
"focal_depth_km", "tectonic_stress_index", "building_vulnerability",
|
| 589 |
+
"population_density_norm", "bedrock_amplification",
|
| 590 |
+
]
|
| 591 |
+
|
| 592 |
+
assert features == list(EARTHQUAKE_FEATURES), (
|
| 593 |
+
f"Feature mismatch!\n train: {features}\n"
|
| 594 |
+
f" predictor: {list(EARTHQUAKE_FEATURES)}"
|
| 595 |
)
|
| 596 |
|
| 597 |
+
from src.disaster_predictors import FEATURE_RANGES
|
| 598 |
+
X = np.zeros((len(base), len(features)), dtype=np.float32)
|
| 599 |
+
for i, feat in enumerate(features):
|
| 600 |
+
lo, hi = FEATURE_RANGES[feat]
|
| 601 |
+
X[:, i] = np.clip(
|
| 602 |
+
(base[feat].values - lo) / (hi - lo + 1e-8), 0, 1
|
| 603 |
+
)
|
| 604 |
|
| 605 |
+
y = base["risk_score"].values.astype(np.float32)
|
| 606 |
+
return X, y
|
| 607 |
|
| 608 |
|
| 609 |
DATA_GENERATORS = {
|
|
|
|
| 614 |
}
|
| 615 |
|
| 616 |
|
| 617 |
+
|
| 618 |
# ============================================================================
|
| 619 |
# TRAINING PIPELINE
|
| 620 |
# ============================================================================
|
|
|
|
| 647 |
print(f"{'='*60}")
|
| 648 |
|
| 649 |
generator_fn, feature_names = DATA_GENERATORS[disaster_type]
|
| 650 |
+
n = n_samples or {
|
| 651 |
+
"flood": 5000, "cyclone": 3000,
|
| 652 |
+
"landslide": 4000, "earthquake": 3000
|
| 653 |
+
}[disaster_type]
|
| 654 |
+
|
| 655 |
+
REAL_DATA_GENERATORS = {"flood", "cyclone", "landslide", "earthquake"}
|
| 656 |
+
if disaster_type in REAL_DATA_GENERATORS:
|
| 657 |
+
print(f"Loading real data for {disaster_type}...")
|
| 658 |
+
else:
|
| 659 |
+
print(f"Generating {n} synthetic samples...")
|
| 660 |
|
|
|
|
| 661 |
X, y = generator_fn(n)
|
| 662 |
+
# ... rest unchanged
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
|
| 664 |
|
| 665 |
def train_all(epochs: int = 200):
|