"""Deep Ensemble for calibrated uncertainty quantification.

Uses independently initialized models (Lakshminarayanan et al., 2017) rather
than MC Dropout — deep ensembles outperform MC Dropout for calibrated
uncertainty. Each member predicts mean and log-variance (heteroscedastic
regression). Final prediction is a mixture of Gaussians from all members.
"""

from pathlib import Path
from statistics import NormalDist
from typing import Any, Optional

import numpy as np
import torch
import torch.nn as nn

from src.models.architecture import PIResMLP


class DeepEnsemble(nn.Module):
    """Ensemble of PIResMLP models for uncertainty quantification.

    Prediction: an equally weighted mixture of Gaussians from N separately
    constructed members. For each regression target:

    * mean = average of member means;
    * variance = average of (member_var + member_mean^2) - ensemble_mean^2,
      which by the law of total variance equals
      E[member_var] (aleatoric) + population Var[member_mean] (epistemic).
    """

    # Bounds applied to predicted log-variances before exponentiation, so a
    # single runaway member cannot yield a vanishing or overflowing variance
    # (exp(-10) ~ 4.5e-5, exp(10) ~ 2.2e4).
    _LOG_VAR_MIN = -10.0
    _LOG_VAR_MAX = 10.0

    def __init__(
        self,
        num_members: int = 5,
        **model_kwargs: Any,
    ) -> None:
        """Build ``num_members`` separately constructed PIResMLP members.

        Args:
            num_members: Number of ensemble members; must be at least 1.
            **model_kwargs: Forwarded unchanged to every ``PIResMLP``.

        Raises:
            ValueError: If ``num_members`` is less than 1.
        """
        if num_members < 1:
            raise ValueError(f"num_members must be >= 1, got {num_members}")
        super().__init__()
        self.num_members = num_members
        self.members = nn.ModuleList(
            [PIResMLP(**model_kwargs) for _ in range(num_members)]
        )

    def forward(self, x: torch.Tensor) -> dict[str, Any]:
        """Run all ensemble members and aggregate their predictions.

        Each member's output is indexed as a dict whose ``"stress"`` and
        ``"deflection"`` entries hold mean in column 0 and log-variance in
        column 1, and whose ``"safety"`` entry holds class logits along dim 1
        (inferred from the indexing below — confirm against PIResMLP).

        Args:
            x: Input batch, passed unchanged to every member.

        Returns:
            Dict with:
                - 'stress_mean': (batch,) ensemble mean prediction
                - 'stress_var': (batch,) total variance (epistemic + aleatoric)
                - 'stress_aleatoric': (batch,) mean of member variances
                - 'stress_epistemic': (batch,) population variance of member means
                - 'deflection_mean', 'deflection_var', 'deflection_aleatoric',
                  'deflection_epistemic': same quantities for deflection
                - 'safety': (batch, 3) averaged softmax probabilities
                - 'member_outputs': list of individual member outputs
        """
        member_outputs = [member(x) for member in self.members]

        result: dict[str, Any] = {}
        for key in ("stress", "deflection"):
            means = torch.stack([out[key][:, 0] for out in member_outputs])  # (M, batch)
            log_vars = torch.stack([out[key][:, 1] for out in member_outputs])
            log_vars = torch.clamp(
                log_vars, min=self._LOG_VAR_MIN, max=self._LOG_VAR_MAX
            )
            vars_ = torch.exp(log_vars)  # (M, batch)

            ensemble_mean = means.mean(dim=0)  # (batch,)

            # Law of total variance for an equally weighted mixture:
            #   Var = E[Var_i] + Var[Mean_i]
            # The second term must be the *population* variance over members
            # (divide by M). torch.var's default Bessel correction (M - 1)
            # would inflate it relative to the mixture formula and return NaN
            # for a single-member ensemble.
            aleatoric = vars_.mean(dim=0)
            epistemic = (means - ensemble_mean).pow(2).mean(dim=0)
            total_var = aleatoric + epistemic

            result[f"{key}_mean"] = ensemble_mean
            result[f"{key}_var"] = total_var
            result[f"{key}_aleatoric"] = aleatoric
            result[f"{key}_epistemic"] = epistemic

        # Safety: average member softmax probabilities (not logits), which
        # keeps the result a valid probability distribution per sample.
        safety_probs = torch.stack(
            [torch.softmax(out["safety"], dim=1) for out in member_outputs]
        )
        result["safety"] = safety_probs.mean(dim=0)
        result["member_outputs"] = member_outputs
        return result

    def predict_with_uncertainty(
        self,
        x: torch.Tensor,
        confidence: float = 0.95,
    ) -> dict[str, torch.Tensor]:
        """Predict with symmetric Gaussian prediction intervals.

        The mixture is approximated by a single Gaussian with the ensemble
        mean and total variance, so each interval is ``mean ± z * std`` where
        ``z`` is the two-sided standard-normal quantile for ``confidence``.

        Inference runs under ``torch.no_grad()`` in eval mode. The module's
        previous train/eval mode is restored before returning.

        Args:
            x: Input tensor.
            confidence: Coverage of the prediction interval, strictly between
                0 and 1 (default 0.95).

        Returns:
            Dict with '<key>_mean', '<key>_lower', '<key>_upper' and
            '<key>_std' for key in ('stress', 'deflection'), plus
            'safety_probs' (averaged class probabilities) and 'safety_class'
            (their argmax over dim 1).

        Raises:
            ValueError: If ``confidence`` is not in the open interval (0, 1).
        """
        # Written so NaN also fails the check.
        if not 0.0 < confidence < 1.0:
            raise ValueError(f"confidence must be in (0, 1), got {confidence}")

        # Two-sided standard-normal quantile, e.g. ~1.96 for 95%.
        z = NormalDist().inv_cdf(0.5 + confidence / 2)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                out = self.forward(x)
        finally:
            # Don't leave a caller that was mid-training stuck in eval mode.
            self.train(was_training)

        result: dict[str, torch.Tensor] = {}
        for key in ("stress", "deflection"):
            mean = out[f"{key}_mean"]
            std = torch.sqrt(out[f"{key}_var"])
            result[f"{key}_mean"] = mean
            result[f"{key}_lower"] = mean - z * std
            result[f"{key}_upper"] = mean + z * std
            result[f"{key}_std"] = std

        result["safety_probs"] = out["safety"]
        result["safety_class"] = out["safety"].argmax(dim=1)
        return result

    def save(self, directory: Path) -> None:
        """Save each member's ``state_dict`` to ``directory/member_{i}.pt``.

        Args:
            directory: Target directory (``Path`` or ``str``); it is created,
                including parents, if it does not exist.
        """
        directory = Path(directory)
        directory.mkdir(parents=True, exist_ok=True)
        for i, member in enumerate(self.members):
            torch.save(member.state_dict(), directory / f"member_{i}.pt")

    @classmethod
    def load(
        cls,
        directory: Path,
        num_members: int = 5,
        **model_kwargs: Any,
    ) -> "DeepEnsemble":
        """Load an ensemble from a directory of per-member checkpoints.

        Expects files named ``member_0.pt`` … ``member_{num_members-1}.pt``,
        as written by :meth:`save`. Weights are loaded onto CPU with
        ``weights_only=True`` so only tensor data is unpickled.

        Args:
            directory: Checkpoint directory (``Path`` or ``str``).
            num_members: Number of members to construct and load.
            **model_kwargs: Forwarded to every ``PIResMLP``; must match the
                architecture the checkpoints were saved from.

        Returns:
            A new ensemble with every member's weights restored.
        """
        directory = Path(directory)
        ensemble = cls(num_members=num_members, **model_kwargs)
        for i, member in enumerate(ensemble.members):
            path = directory / f"member_{i}.pt"
            member.load_state_dict(
                torch.load(path, map_location="cpu", weights_only=True)
            )
        return ensemble