Spaces:
Sleeping
Sleeping
| """Evaluation pipeline for the structural analysis surrogate. | |
| Computes R2, MAPE, max error per problem family, calibration metrics, | |
| and generates comparison tables and plots. | |
| Usage: | |
| python -m src.training.evaluate --config configs/training.yaml | |
| """ | |
| import argparse | |
| import json | |
| import logging | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import yaml | |
| from sklearn.metrics import r2_score | |
| from src.models.ensemble import DeepEnsemble | |
| from src.models.normalization import LogTransformStandardizer | |
| from src.training.dataset import create_dataloaders | |
| from src.utils.device import get_device | |
| logger = logging.getLogger(__name__) | |
def evaluate_ensemble(
    ensemble: DeepEnsemble,
    test_loader: torch.utils.data.DataLoader,
    device: torch.device,
) -> dict:
    """Evaluate the ensemble on the test set and summarise its accuracy.

    The ensemble is moved to ``device``, put in eval mode and run over every
    batch of ``test_loader`` with gradients disabled. Predictions, predictive
    standard deviations (square root of the reported variances) and targets
    are gathered as NumPy arrays and reduced to a handful of scalar metrics.

    Args:
        ensemble: Model whose forward pass returns a mapping with the keys
            ``stress_mean``, ``deflection_mean``, ``stress_var``,
            ``deflection_var`` and ``safety``.
        test_loader: Yields ``(inputs, targets)`` pairs, where ``targets`` is
            a mapping with the keys ``log_stress``, ``log_deflection`` and
            ``safety_class``.
        device: Device on which the forward pass is executed.

    Returns:
        Dict with a ``"stress"`` and a ``"deflection"`` entry, each a dict
        holding ``r2``, ``mape_percent``, ``max_ape_percent``,
        ``rmse_log10`` and ``coverage_95_percent``, plus a top-level
        ``"safety_accuracy_percent"`` float.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    ensemble = ensemble.to(device)
    ensemble.eval()

    all_stress_pred = []
    all_stress_true = []
    all_defl_pred = []
    all_defl_true = []
    all_stress_std = []
    all_defl_std = []
    all_safety_pred = []
    all_safety_true = []

    with torch.no_grad():
        for X_batch, targets in test_loader:
            X_batch = X_batch.to(device)
            out = ensemble(X_batch)

            all_stress_pred.append(out["stress_mean"].cpu().numpy())
            all_defl_pred.append(out["deflection_mean"].cpu().numpy())
            all_stress_std.append(torch.sqrt(out["stress_var"]).cpu().numpy())
            all_defl_std.append(torch.sqrt(out["deflection_var"]).cpu().numpy())

            # Targets go through .cpu() as well: Tensor.numpy() only works on
            # CPU tensors, and .cpu() is a no-op when they already live there.
            all_stress_true.append(targets["log_stress"].cpu().numpy())
            all_defl_true.append(targets["log_deflection"].cpu().numpy())

            # Predicted class = index of the largest entry along dim 1.
            all_safety_pred.append(out["safety"].argmax(dim=1).cpu().numpy())
            all_safety_true.append(targets["safety_class"].cpu().numpy())

    # np.concatenate rejects an empty list with an unrelated-looking message,
    # so report the real cause instead.
    if not all_stress_pred:
        raise ValueError(
            "test_loader yielded no batches; cannot compute evaluation metrics"
        )

    # Concatenate the per-batch arrays into one array per quantity.
    stress_pred = np.concatenate(all_stress_pred)
    stress_true = np.concatenate(all_stress_true)
    defl_pred = np.concatenate(all_defl_pred)
    defl_true = np.concatenate(all_defl_true)
    stress_std = np.concatenate(all_stress_std)
    defl_std = np.concatenate(all_defl_std)
    safety_pred = np.concatenate(all_safety_pred)
    safety_true = np.concatenate(all_safety_true)

    # z-multiplier used for the nominal 95% interval below.
    z95 = 1.96

    # Metrics in log-space (predictions are in log10).
    # NOTE(review): pred/true/std are combined element-wise, which assumes the
    # model outputs and the targets share the same shape; a (N, 1) vs (N,)
    # mismatch would broadcast silently -- TODO confirm against the dataset.
    metrics = {}
    for name, pred, true, std in [
        ("stress", stress_pred, stress_true, stress_std),
        ("deflection", defl_pred, defl_true, defl_std),
    ]:
        r2 = r2_score(true, pred)

        # Percentage errors in original space:
        # |10^pred - 10^true| / 10^true * 100
        pred_orig = 10.0 ** pred
        true_orig = 10.0 ** true
        ape = np.abs(pred_orig - true_orig) / true_orig
        mape = np.mean(ape) * 100.0
        max_ape = np.max(ape) * 100.0

        # RMSE in log-space
        rmse_log = np.sqrt(np.mean((pred - true) ** 2))

        # Calibration: what fraction of test points fall within the
        # predicted 95% interval (pred +/- 1.96 * std)?
        lower = pred - z95 * std
        upper = pred + z95 * std
        coverage_95 = np.mean((true >= lower) & (true <= upper)) * 100.0

        metrics[name] = {
            "r2": float(r2),
            "mape_percent": float(mape),
            "max_ape_percent": float(max_ape),
            "rmse_log10": float(rmse_log),
            "coverage_95_percent": float(coverage_95),
        }

    # Safety classification accuracy
    safety_acc = np.mean(safety_pred == safety_true) * 100.0
    metrics["safety_accuracy_percent"] = float(safety_acc)
    return metrics
def main() -> None:
    """Command-line entry point: evaluate a saved ensemble and store the metrics.

    Steps, in order:
      1. Parse ``--config`` (default ``configs/training.yaml``) and load it as YAML.
      2. Restore the normalizer, the model keyword arguments and the ensemble
         from ``config["output"]["checkpoint_dir"]``.
      3. Build the dataloaders from ``config["data"]`` and keep the third one
         as the test loader.
      4. Run ``evaluate_ensemble`` and log every metric with four decimals.
      5. Write the metrics to ``eval_results.json`` inside
         ``config["output"]["figures_dir"]`` (default ``artifacts/figures``),
         creating that directory if needed.
    """
    cli = argparse.ArgumentParser(description="Evaluate PI-ResMLP ensemble")
    cli.add_argument("--config", default="configs/training.yaml")
    args = cli.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

    with open(args.config) as config_file:
        config = yaml.safe_load(config_file)

    device = get_device()

    # Restore everything that was saved next to the checkpoints.
    ckpt_root = Path(config["output"]["checkpoint_dir"])
    normalizer = LogTransformStandardizer.load(ckpt_root / "normalization_params.json")
    with open(ckpt_root / "model_config.json") as kwargs_file:
        model_kwargs = json.load(kwargs_file)
    ensemble = DeepEnsemble.load(
        ckpt_root / "model_ensemble",
        num_members=config["model"]["num_ensemble_members"],
        **model_kwargs,
    )

    # Only the last of the three returned loaders is needed here.
    data_root = Path(config["data"]["directory"])
    _, _, test_loader = create_dataloaders(
        data_root,
        normalizer,
        batch_size=config["data"]["batch_size"],
    )

    metrics = evaluate_ensemble(ensemble, test_loader, device)

    # Report: nested dicts get a heading, scalars are logged directly.
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("EVALUATION RESULTS")
    logger.info(banner)
    for metric_name, metric_value in metrics.items():
        if not isinstance(metric_value, dict):
            logger.info(f"{metric_name}: {metric_value:.4f}")
            continue
        logger.info(f"\n{metric_name.upper()}:")
        for sub_name, sub_value in metric_value.items():
            logger.info(f" {sub_name}: {sub_value:.4f}")

    # Persist the metrics as JSON.
    out_dir = Path(config["output"].get("figures_dir", "artifacts/figures"))
    out_dir.mkdir(parents=True, exist_ok=True)
    results_path = out_dir / "eval_results.json"
    with open(results_path, "w") as results_file:
        json.dump(metrics, results_file, indent=2)
    logger.info(f"\nResults saved to {results_path}")
# Run the evaluation only when this file is executed as a script
# (e.g. ``python -m src.training.evaluate``), not when it is imported.
if __name__ == "__main__":
    main()