# fea-surrogate / src/models/ensemble.py
"""Deep Ensemble for calibrated uncertainty quantification.
Uses independently initialized models (Lakshminarayanan et al., 2017) rather
than MC Dropout — deep ensembles outperform MC Dropout for calibrated uncertainty.
Each member predicts mean and log-variance (heteroscedastic regression).
Final prediction is a mixture of Gaussians from all members.
"""
from pathlib import Path
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from src.models.architecture import PIResMLP
class DeepEnsemble(nn.Module):
    """Ensemble of PIResMLP models for uncertainty quantification.

    Prediction: mixture of Gaussians from N independently trained members.
    Mean = average of member means.
    Variance = average of (member_var + member_mean^2) - ensemble_mean^2
    (law of total variance).
    """

    def __init__(
        self,
        num_members: int = 5,
        **model_kwargs,
    ) -> None:
        """Build an ensemble of independently initialized members.

        Args:
            num_members: Number of ensemble members (must be >= 1).
            **model_kwargs: Keyword arguments forwarded to each PIResMLP.

        Raises:
            ValueError: If ``num_members`` is less than 1.
        """
        super().__init__()
        # Guard: an empty ensemble would crash forward() on torch.stack([]).
        if num_members < 1:
            raise ValueError(f"num_members must be >= 1, got {num_members}")
        self.num_members = num_members
        self.members = nn.ModuleList([
            PIResMLP(**model_kwargs) for _ in range(num_members)
        ])

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run all ensemble members and aggregate.

        Returns:
            Dict with:
                - 'stress_mean': (batch,) ensemble mean prediction
                - 'stress_var': (batch,) total variance (epistemic + aleatoric)
                - 'stress_aleatoric': (batch,) mean of member variances
                - 'stress_epistemic': (batch,) variance of member means
                - 'deflection_mean' / 'deflection_var' /
                  'deflection_aleatoric' / 'deflection_epistemic': (batch,)
                - 'safety': (batch, 3) averaged softmax probabilities
                - 'member_outputs': list of individual member outputs
        """
        member_outputs = [member(x) for member in self.members]
        result = {}
        for key in ["stress", "deflection"]:
            means = torch.stack([out[key][:, 0] for out in member_outputs])  # (M, batch)
            log_vars = torch.stack([out[key][:, 1] for out in member_outputs])
            # Clamp before exp() to keep the variance numerically stable.
            log_vars = torch.clamp(log_vars, min=-10.0, max=10.0)
            vars_ = torch.exp(log_vars)  # (M, batch)
            # Ensemble mean
            ensemble_mean = means.mean(dim=0)  # (batch,)
            # Total variance via law of total variance:
            #   Var = E[Var_i] + Var[Mean_i]
            aleatoric = vars_.mean(dim=0)  # E[Var_i]
            # Population variance (unbiased=False) matches the documented
            # mixture formula E[mean_i^2] - (E[mean_i])^2; torch's default
            # unbiased estimator would divide by M-1 instead of M.
            epistemic = means.var(dim=0, unbiased=False)  # Var[Mean_i]
            total_var = aleatoric + epistemic
            result[f"{key}_mean"] = ensemble_mean
            result[f"{key}_var"] = total_var
            result[f"{key}_aleatoric"] = aleatoric
            result[f"{key}_epistemic"] = epistemic
        # Safety: average the member softmax probabilities (mixture of
        # categoricals), not the raw logits.
        safety_probs = torch.stack([
            torch.softmax(out["safety"], dim=1) for out in member_outputs
        ])
        result["safety"] = safety_probs.mean(dim=0)
        result["member_outputs"] = member_outputs
        return result

    def predict_with_uncertainty(
        self,
        x: torch.Tensor,
        confidence: float = 0.95,
    ) -> dict[str, torch.Tensor]:
        """Predict with Gaussian confidence intervals.

        Args:
            x: Input tensor.
            confidence: Confidence level for prediction interval (default 95%).

        Returns:
            Dict with mean, lower, upper bounds and std for stress and
            deflection, plus 'safety_probs' and argmax 'safety_class'.
        """
        self.eval()
        with torch.no_grad():
            out = self.forward(x)
        # z-score for the two-sided interval (Gaussian approximation).
        # stdlib NormalDist.inv_cdf == scipy.stats.norm.ppf; avoids the
        # scipy dependency the original carried just for this one call.
        from statistics import NormalDist
        z = NormalDist().inv_cdf(0.5 + confidence / 2)
        result = {}
        for key in ["stress", "deflection"]:
            mean = out[f"{key}_mean"]
            std = torch.sqrt(out[f"{key}_var"])
            result[f"{key}_mean"] = mean
            result[f"{key}_lower"] = mean - z * std
            result[f"{key}_upper"] = mean + z * std
            result[f"{key}_std"] = std
        result["safety_probs"] = out["safety"]
        result["safety_class"] = out["safety"].argmax(dim=1)
        return result

    def save(self, directory: Path) -> None:
        """Save each ensemble member's state dict as ``member_{i}.pt``."""
        directory.mkdir(parents=True, exist_ok=True)
        for i, member in enumerate(self.members):
            torch.save(member.state_dict(), directory / f"member_{i}.pt")

    @classmethod
    def load(cls, directory: Path, num_members: int = 5, **model_kwargs) -> "DeepEnsemble":
        """Load ensemble from directory of member checkpoints.

        Args:
            directory: Directory containing ``member_{i}.pt`` files.
            num_members: Number of members to load (must match what was saved).
            **model_kwargs: Keyword arguments forwarded to each PIResMLP.
        """
        ensemble = cls(num_members=num_members, **model_kwargs)
        for i, member in enumerate(ensemble.members):
            path = directory / f"member_{i}.pt"
            # weights_only=True: only deserialize tensors, never pickled code.
            member.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
        return ensemble