"""Deep Ensemble for calibrated uncertainty quantification.

Uses independently initialized models (Lakshminarayanan et al., 2017) rather
than MC Dropout — deep ensembles outperform MC Dropout for calibrated uncertainty.
Each member predicts mean and log-variance (heteroscedastic regression).
Final prediction is a mixture of Gaussians from all members.
"""
from pathlib import Path
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from src.models.architecture import PIResMLP
class DeepEnsemble(nn.Module):
    """Ensemble of PIResMLP models for uncertainty quantification.

    Prediction: mixture of Gaussians from N independently trained members.

    Mean = average of member means.
    Variance = average of (member_var + member_mean^2) - ensemble_mean^2
    (law of total variance).
    """

    def __init__(
        self,
        num_members: int = 5,
        **model_kwargs,
    ) -> None:
        """Build ``num_members`` independently initialized PIResMLP members.

        Args:
            num_members: Number of ensemble members.
            **model_kwargs: Forwarded verbatim to each PIResMLP constructor.
        """
        super().__init__()
        self.num_members = num_members
        # Independent random initializations are the source of ensemble
        # diversity (and hence of the epistemic-uncertainty estimate).
        self.members = nn.ModuleList(
            [PIResMLP(**model_kwargs) for _ in range(num_members)]
        )

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run all ensemble members and aggregate.

        Returns:
            Dict with:
                - 'stress_mean': (batch,) ensemble mean prediction
                - 'stress_var': (batch,) total variance (epistemic + aleatoric)
                - 'stress_aleatoric' / 'stress_epistemic': (batch,) components
                - 'deflection_mean': (batch,)
                - 'deflection_var': (batch,)
                - 'deflection_aleatoric' / 'deflection_epistemic': (batch,)
                - 'safety': (batch, 3) averaged softmax probabilities
                - 'member_outputs': list of individual member outputs
        """
        member_outputs = [member(x) for member in self.members]
        result: dict[str, torch.Tensor] = {}

        for key in ["stress", "deflection"]:
            # Each member emits (batch, 2): column 0 = mean, column 1 = log-variance.
            means = torch.stack([out[key][:, 0] for out in member_outputs])  # (M, batch)
            log_vars = torch.stack([out[key][:, 1] for out in member_outputs])
            # Clamp before exp so variances stay in a numerically safe range.
            log_vars = torch.clamp(log_vars, min=-10.0, max=10.0)
            vars_ = torch.exp(log_vars)  # (M, batch)

            ensemble_mean = means.mean(dim=0)  # (batch,)
            # Law of total variance: Var = E[Var_i] + Var[Mean_i].
            aleatoric = vars_.mean(dim=0)  # E[Var_i]
            # NOTE: torch.var defaults to the unbiased (N-1) estimator over
            # the M member means.
            epistemic = means.var(dim=0)  # Var[Mean_i]
            total_var = aleatoric + epistemic

            result[f"{key}_mean"] = ensemble_mean
            result[f"{key}_var"] = total_var
            result[f"{key}_aleatoric"] = aleatoric
            result[f"{key}_epistemic"] = epistemic

        # Safety head: average the per-member softmax probabilities.
        safety_probs = torch.stack(
            [torch.softmax(out["safety"], dim=1) for out in member_outputs]
        )
        result["safety"] = safety_probs.mean(dim=0)
        result["member_outputs"] = member_outputs
        return result

    def predict_with_uncertainty(
        self,
        x: torch.Tensor,
        confidence: float = 0.95,
    ) -> dict[str, torch.Tensor]:
        """Predict with Gaussian confidence intervals.

        Args:
            x: Input tensor.
            confidence: Confidence level for the prediction interval (default 95%).

        Returns:
            Dict with mean, lower, upper bounds and std for stress and
            deflection, plus 'safety_probs' and the argmax 'safety_class'.
        """
        # Stdlib replacement for scipy.stats.norm.ppf — avoids pulling in
        # SciPy at runtime for a single Gaussian quantile lookup.
        from statistics import NormalDist

        self.eval()
        with torch.no_grad():
            out = self.forward(x)

        # Two-sided z-score under a Gaussian approximation of the mixture.
        z = NormalDist().inv_cdf(0.5 + confidence / 2)

        result = {}
        for key in ["stress", "deflection"]:
            mean = out[f"{key}_mean"]
            std = torch.sqrt(out[f"{key}_var"])
            result[f"{key}_mean"] = mean
            result[f"{key}_lower"] = mean - z * std
            result[f"{key}_upper"] = mean + z * std
            result[f"{key}_std"] = std
        result["safety_probs"] = out["safety"]
        result["safety_class"] = out["safety"].argmax(dim=1)
        return result

    def save(self, directory: Path) -> None:
        """Save each ensemble member's state_dict as a separate file."""
        directory.mkdir(parents=True, exist_ok=True)
        for i, member in enumerate(self.members):
            torch.save(member.state_dict(), directory / f"member_{i}.pt")

    @classmethod
    def load(cls, directory: Path, num_members: int = 5, **model_kwargs) -> "DeepEnsemble":
        """Load ensemble from directory of member checkpoints.

        Bug fix: the original took ``cls`` but lacked the ``@classmethod``
        decorator, so ``DeepEnsemble.load(path)`` would have bound the
        directory argument to ``cls`` and failed.
        """
        ensemble = cls(num_members=num_members, **model_kwargs)
        for i, member in enumerate(ensemble.members):
            path = directory / f"member_{i}.pt"
            member.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
        return ensemble