# fea-surrogate / src / training / evaluate.py
# WolfDavid's picture
# Upload folder using huggingface_hub
# 8e5ba9e verified
"""Evaluation pipeline for the structural analysis surrogate.
Computes R2, MAPE, max error per problem family, calibration metrics,
and generates comparison tables and plots.
Usage:
python -m src.training.evaluate --config configs/training.yaml
"""
import argparse
import json
import logging
from pathlib import Path
import numpy as np
import torch
import yaml
from sklearn.metrics import r2_score
from src.models.ensemble import DeepEnsemble
from src.models.normalization import LogTransformStandardizer
from src.training.dataset import create_dataloaders
from src.utils.device import get_device
logger = logging.getLogger(__name__)
def evaluate_ensemble(
    ensemble: DeepEnsemble,
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
) -> dict:
    """Run the ensemble over the test set and compute evaluation metrics.

    Args:
        ensemble: Model to evaluate. Its forward pass must return a mapping
            with the keys ``stress_mean``, ``stress_var``,
            ``deflection_mean``, ``deflection_var`` and ``safety`` (class
            scores along dim 1).
        test_loader: Yields ``(inputs, targets)`` batches, where ``targets``
            is a mapping with the keys ``log_stress``, ``log_deflection``
            and ``safety_class``.
        device: Device the model and the input batches are moved to.

    Returns:
        Dict with a ``"stress"`` and a ``"deflection"`` sub-dict (keys
        ``r2``, ``mape_percent``, ``max_ape_percent``, ``rmse_log10``,
        ``coverage_95_percent``) and a top-level
        ``"safety_accuracy_percent"`` float.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    ensemble = ensemble.to(device)
    ensemble.eval()

    # One list of per-batch 1-D arrays for every quantity we need afterwards.
    columns = {
        name: []
        for name in (
            "stress_pred", "stress_true", "stress_std",
            "defl_pred", "defl_true", "defl_std",
            "safety_pred", "safety_true",
        )
    }

    with torch.no_grad():
        for X_batch, targets in test_loader:
            out = ensemble(X_batch.to(device))

            columns["stress_pred"].append(_to_flat_numpy(out["stress_mean"]))
            columns["defl_pred"].append(_to_flat_numpy(out["deflection_mean"]))
            # The ensemble reports variances; the interval check needs std.
            columns["stress_std"].append(_to_flat_numpy(torch.sqrt(out["stress_var"])))
            columns["defl_std"].append(_to_flat_numpy(torch.sqrt(out["deflection_var"])))
            columns["stress_true"].append(_to_flat_numpy(targets["log_stress"]))
            columns["defl_true"].append(_to_flat_numpy(targets["log_deflection"]))
            columns["safety_pred"].append(_to_flat_numpy(out["safety"].argmax(dim=1)))
            columns["safety_true"].append(_to_flat_numpy(targets["safety_class"]))

    if not columns["stress_pred"]:
        # np.concatenate([]) would raise an opaque ValueError; say what happened.
        raise ValueError("test_loader yielded no batches; cannot compute evaluation metrics")

    arrays = {name: np.concatenate(chunks) for name, chunks in columns.items()}

    # Metrics in log-space (predictions are in log10)
    metrics = {
        "stress": _regression_metrics(
            arrays["stress_pred"], arrays["stress_true"], arrays["stress_std"]
        ),
        "deflection": _regression_metrics(
            arrays["defl_pred"], arrays["defl_true"], arrays["defl_std"]
        ),
    }

    # Safety classification accuracy
    safety_acc = np.mean(arrays["safety_pred"] == arrays["safety_true"]) * 100.0
    metrics["safety_accuracy_percent"] = float(safety_acc)
    return metrics


def _to_flat_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Return *tensor* as a 1-D NumPy array on the CPU.

    Flattening guarantees that predictions shaped ``(N, 1)`` and targets
    shaped ``(N,)`` line up element-wise instead of silently broadcasting to
    an ``(N, N)`` matrix in the metric arithmetic. Assumes one scalar per
    sample for every collected quantity.
    """
    return tensor.cpu().numpy().reshape(-1)


def _regression_metrics(
    pred: np.ndarray,
    true: np.ndarray,
    std: np.ndarray,
    z_score: float = 1.96,
) -> dict:
    """Compute fit and calibration metrics for one log10-space output.

    Args:
        pred: Predicted values in log10 space, 1-D.
        true: Ground-truth values in log10 space, 1-D.
        std: Predictive standard deviation in log10 space, 1-D.
        z_score: Half-width of the interval in standard deviations
            (1.96 corresponds to a 95% normal interval).

    Returns:
        Dict with ``r2``, ``mape_percent``, ``max_ape_percent``,
        ``rmse_log10`` and ``coverage_95_percent`` as Python floats.
    """
    r2 = r2_score(true, pred)

    # Percentage errors are reported in original space: |10^pred - 10^true| / 10^true.
    pred_orig = 10.0 ** pred
    true_orig = 10.0 ** true
    abs_pct_err = np.abs(pred_orig - true_orig) / true_orig
    mape = np.mean(abs_pct_err) * 100.0
    max_ape = np.max(abs_pct_err) * 100.0

    # RMSE in log-space
    rmse_log = np.sqrt(np.mean((pred - true) ** 2))

    # Calibration: fraction of test points inside the predicted interval.
    lower = pred - z_score * std
    upper = pred + z_score * std
    coverage_95 = np.mean((true >= lower) & (true <= upper)) * 100.0

    return {
        "r2": float(r2),
        "mape_percent": float(mape),
        "max_ape_percent": float(max_ape),
        "rmse_log10": float(rmse_log),
        "coverage_95_percent": float(coverage_95),
    }
def main() -> None:
    """Command-line entry point: evaluate a saved ensemble and store the metrics.

    Reads the YAML file given by ``--config``, loads the normalizer, model
    configuration and ensemble from ``config["output"]["checkpoint_dir"]``,
    evaluates on the test split, logs every metric and writes them to
    ``eval_results.json`` inside ``config["output"]["figures_dir"]``
    (default ``artifacts/figures``).
    """
    parser = argparse.ArgumentParser(description="Evaluate PI-ResMLP ensemble")
    parser.add_argument("--config", default="configs/training.yaml")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

    # Explicit UTF-8: the platform default encoding differs between systems
    # and would mis-decode non-ASCII content in the config.
    with open(args.config, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    device = get_device()

    # Load normalizer and model
    checkpoint_dir = Path(config["output"]["checkpoint_dir"])
    normalizer = LogTransformStandardizer.load(checkpoint_dir / "normalization_params.json")
    with open(checkpoint_dir / "model_config.json", encoding="utf-8") as f:
        model_kwargs = json.load(f)
    ensemble = DeepEnsemble.load(
        checkpoint_dir / "model_ensemble",
        num_members=config["model"]["num_ensemble_members"],
        **model_kwargs,
    )

    # Create test dataloader (the first two returned loaders are not needed here)
    data_dir = Path(config["data"]["directory"])
    _, _, test_loader = create_dataloaders(
        data_dir, normalizer,
        batch_size=config["data"]["batch_size"],
    )

    # Evaluate
    metrics = evaluate_ensemble(ensemble, test_loader, device)

    # Print results: nested dicts are per-output metric groups, floats are scalars.
    logger.info("\n" + "=" * 60)
    logger.info("EVALUATION RESULTS")
    logger.info("=" * 60)
    for key, value in metrics.items():
        if isinstance(value, dict):
            logger.info(f"\n{key.upper()}:")
            for k, v in value.items():
                logger.info(f" {k}: {v:.4f}")
        else:
            logger.info(f"{key}: {value:.4f}")

    # Save results
    results_dir = Path(config["output"].get("figures_dir", "artifacts/figures"))
    results_dir.mkdir(parents=True, exist_ok=True)
    with open(results_dir / "eval_results.json", "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=2)
    logger.info(f"\nResults saved to {results_dir / 'eval_results.json'}")
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()