"""Evaluation pipeline for the structural analysis surrogate. Computes R2, MAPE, max error per problem family, calibration metrics, and generates comparison tables and plots. Usage: python -m src.training.evaluate --config configs/training.yaml """ import argparse import json import logging from pathlib import Path import numpy as np import torch import yaml from sklearn.metrics import r2_score from src.models.ensemble import DeepEnsemble from src.models.normalization import LogTransformStandardizer from src.training.dataset import create_dataloaders from src.utils.device import get_device logger = logging.getLogger(__name__) def evaluate_ensemble( ensemble: DeepEnsemble, test_loader: torch.utils.data.DataLoader, device: torch.device, ) -> dict: """Run evaluation on test set. Returns: Dict with metrics per output (stress, deflection) and overall. """ ensemble = ensemble.to(device) ensemble.eval() all_stress_pred = [] all_stress_true = [] all_defl_pred = [] all_defl_true = [] all_stress_std = [] all_defl_std = [] all_safety_pred = [] all_safety_true = [] with torch.no_grad(): for X_batch, targets in test_loader: X_batch = X_batch.to(device) out = ensemble(X_batch) all_stress_pred.append(out["stress_mean"].cpu().numpy()) all_defl_pred.append(out["deflection_mean"].cpu().numpy()) all_stress_std.append(torch.sqrt(out["stress_var"]).cpu().numpy()) all_defl_std.append(torch.sqrt(out["deflection_var"]).cpu().numpy()) all_stress_true.append(targets["log_stress"].numpy()) all_defl_true.append(targets["log_deflection"].numpy()) all_safety_pred.append(out["safety"].argmax(dim=1).cpu().numpy()) all_safety_true.append(targets["safety_class"].numpy()) # Concatenate stress_pred = np.concatenate(all_stress_pred) stress_true = np.concatenate(all_stress_true) defl_pred = np.concatenate(all_defl_pred) defl_true = np.concatenate(all_defl_true) stress_std = np.concatenate(all_stress_std) defl_std = np.concatenate(all_defl_std) safety_pred = np.concatenate(all_safety_pred) safety_true = np.concatenate(all_safety_true) # Metrics in log-space (predictions are in log10) metrics = {} for name, pred, true, std in [ ("stress", stress_pred, stress_true, stress_std), ("deflection", defl_pred, defl_true, defl_std), ]: r2 = r2_score(true, pred) # MAPE in original space: |10^pred - 10^true| / 10^true * 100 pred_orig = 10.0 ** pred true_orig = 10.0 ** true mape = np.mean(np.abs(pred_orig - true_orig) / true_orig) * 100.0 # Max absolute percentage error max_ape = np.max(np.abs(pred_orig - true_orig) / true_orig) * 100.0 # RMSE in log-space rmse_log = np.sqrt(np.mean((pred - true) ** 2)) # Calibration: what fraction of test points fall within predicted 95% CI? z95 = 1.96 lower = pred - z95 * std upper = pred + z95 * std coverage_95 = np.mean((true >= lower) & (true <= upper)) * 100.0 metrics[name] = { "r2": float(r2), "mape_percent": float(mape), "max_ape_percent": float(max_ape), "rmse_log10": float(rmse_log), "coverage_95_percent": float(coverage_95), } # Safety classification accuracy safety_acc = np.mean(safety_pred == safety_true) * 100.0 metrics["safety_accuracy_percent"] = float(safety_acc) return metrics def main() -> None: parser = argparse.ArgumentParser(description="Evaluate PI-ResMLP ensemble") parser.add_argument("--config", default="configs/training.yaml") args = parser.parse_args() logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") with open(args.config) as f: config = yaml.safe_load(f) device = get_device() # Load normalizer and model checkpoint_dir = Path(config["output"]["checkpoint_dir"]) normalizer = LogTransformStandardizer.load(checkpoint_dir / "normalization_params.json") with open(checkpoint_dir / "model_config.json") as f: model_kwargs = json.load(f) ensemble = DeepEnsemble.load( checkpoint_dir / "model_ensemble", num_members=config["model"]["num_ensemble_members"], **model_kwargs, ) # Create test dataloader data_dir = Path(config["data"]["directory"]) _, _, test_loader = create_dataloaders( data_dir, normalizer, batch_size=config["data"]["batch_size"], ) # Evaluate metrics = evaluate_ensemble(ensemble, test_loader, device) # Print results logger.info("\n" + "=" * 60) logger.info("EVALUATION RESULTS") logger.info("=" * 60) for key, value in metrics.items(): if isinstance(value, dict): logger.info(f"\n{key.upper()}:") for k, v in value.items(): logger.info(f" {k}: {v:.4f}") else: logger.info(f"{key}: {value:.4f}") # Save results results_dir = Path(config["output"].get("figures_dir", "artifacts/figures")) results_dir.mkdir(parents=True, exist_ok=True) with open(results_dir / "eval_results.json", "w") as f: json.dump(metrics, f, indent=2) logger.info(f"\nResults saved to {results_dir / 'eval_results.json'}") if __name__ == "__main__": main()