Spaces:
Sleeping
Sleeping
Update src/train_model.py
Browse files- src/train_model.py +18 -9
src/train_model.py
CHANGED
|
@@ -588,6 +588,11 @@ def generate_earthquake_data(n: int = 3000):
|
|
| 588 |
"focal_depth_km", "tectonic_stress_index", "building_vulnerability",
|
| 589 |
"population_density_norm", "bedrock_amplification",
|
| 590 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 591 |
|
| 592 |
assert features == list(EARTHQUAKE_FEATURES), (
|
| 593 |
f"Feature mismatch!\n train: {features}\n"
|
|
@@ -622,22 +627,24 @@ DATA_GENERATORS = {
|
|
| 622 |
def evaluate_model(model: FuzzyNeuralNetwork, X: torch.Tensor, y: torch.Tensor) -> dict:
|
| 623 |
model.eval()
|
| 624 |
with torch.no_grad():
|
| 625 |
-
preds = model(X).numpy()
|
| 626 |
y_np = y.numpy()
|
| 627 |
|
| 628 |
-
|
| 629 |
try:
|
| 630 |
-
auc = roc_auc_score((y_np >
|
| 631 |
-
except
|
| 632 |
-
auc = float(
|
|
|
|
| 633 |
|
| 634 |
mae = mean_absolute_error(y_np, preds)
|
| 635 |
|
| 636 |
return {
|
| 637 |
-
"MAE":
|
| 638 |
-
"AUC-ROC":
|
| 639 |
"Mean Prediction": round(float(preds.mean()), 4),
|
| 640 |
-
"
|
|
|
|
| 641 |
}
|
| 642 |
|
| 643 |
|
|
@@ -672,7 +679,9 @@ def train_all(epochs: int = 200):
|
|
| 672 |
print(" TRAINING SUMMARY")
|
| 673 |
print("="*60)
|
| 674 |
for dt, metrics in results.items():
|
| 675 |
-
|
|
|
|
|
|
|
| 676 |
print("="*60)
|
| 677 |
|
| 678 |
|
|
|
|
| 588 |
"focal_depth_km", "tectonic_stress_index", "building_vulnerability",
|
| 589 |
"population_density_norm", "bedrock_amplification",
|
| 590 |
]
|
| 591 |
+
# Add immediately before the assert
|
| 592 |
+
print(f"[Earthquake] Columns available before assert: {list(base.columns)}")
|
| 593 |
+
print(f"[Earthquake] Required: {required}")
|
| 594 |
+
print(f"[Earthquake] Missing: {[c for c in required if c not in base.columns]}")
|
| 595 |
+
|
| 596 |
|
| 597 |
assert features == list(EARTHQUAKE_FEATURES), (
|
| 598 |
f"Feature mismatch!\n train: {features}\n"
|
|
|
|
| 627 |
def evaluate_model(model: FuzzyNeuralNetwork, X: torch.Tensor, y: torch.Tensor) -> dict:
    """Evaluate a trained model on (X, y) and return summary metrics.

    Labels are binarized at their own median so an AUC-ROC score can be
    computed against the continuous predictions; if AUC is undefined
    (e.g. binarization yields a single class) it is reported as the
    string "nan". All numeric metrics are rounded to 4 decimal places.
    """
    model.eval()
    # Pure inference: skip autograd bookkeeping while running the forward pass.
    with torch.no_grad():
        predictions = model(X).numpy().flatten()
    targets = y.numpy()

    # Median split turns the regression-style target into two classes for AUC.
    cutoff = float(np.median(targets))
    try:
        auc = roc_auc_score((targets > cutoff).astype(int), predictions)
    except ValueError as err:
        # roc_auc_score raises ValueError when only one class is present.
        auc = float("nan")
        print(f" [Warning] AUC undefined: {err}")

    mae = mean_absolute_error(targets, predictions)

    auc_metric = "nan" if np.isnan(auc) else round(float(auc), 4)
    return {
        "MAE": round(float(mae), 4),
        "AUC-ROC": auc_metric,
        "Mean Prediction": round(float(predictions.mean()), 4),
        "Mean Label": round(float(targets.mean()), 4),
        "Std Prediction": round(float(predictions.std()), 4),
    }
|
| 649 |
|
| 650 |
|
|
|
|
| 679 |
print(" TRAINING SUMMARY")
|
| 680 |
print("="*60)
|
| 681 |
for dt, metrics in results.items():
|
| 682 |
+
auc = metrics["AUC-ROC"]
|
| 683 |
+
auc_str = f"{auc:.4f}" if isinstance(auc, float) and not np.isnan(auc) else str(auc)
|
| 684 |
+
print(f" {dt.upper():12s} | MAE: {metrics['MAE']:.4f} | AUC: {auc_str}")
|
| 685 |
print("="*60)
|
| 686 |
|
| 687 |
|