File size: 4,798 Bytes
8e5ba9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""Deep Ensemble for calibrated uncertainty quantification.

Uses independently initialized models (Lakshminarayanan et al., 2017) rather
than MC Dropout — deep ensembles outperform MC Dropout for calibrated uncertainty.
Each member predicts mean and log-variance (heteroscedastic regression).
Final prediction is a mixture of Gaussians from all members.
"""

from pathlib import Path
from statistics import NormalDist
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from src.models.architecture import PIResMLP


class DeepEnsemble(nn.Module):
    """Ensemble of PIResMLP models for uncertainty quantification.

    Prediction: mixture of Gaussians from N independently trained members.
    Mean = average of member means.
    Variance = average of (member_var + member_mean^2) - ensemble_mean^2
    (law of total variance), i.e. the mean of the member variances
    (aleatoric) plus the population (divide-by-N) variance of the member
    means (epistemic).
    """

    def __init__(
        self,
        num_members: int = 5,
        **model_kwargs: object,
    ) -> None:
        """Build ``num_members`` independently initialised PIResMLP members.

        Args:
            num_members: Number of ensemble members; must be at least 1.
            **model_kwargs: Forwarded unchanged to every ``PIResMLP`` constructor.

        Raises:
            ValueError: If ``num_members`` is less than 1.
        """
        super().__init__()
        if num_members < 1:
            raise ValueError(f"num_members must be >= 1, got {num_members}")
        self.num_members = num_members
        self.members = nn.ModuleList(
            [PIResMLP(**model_kwargs) for _ in range(num_members)]
        )

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run all ensemble members and aggregate their predictions.

        Each member output's "stress" and "deflection" entries are read as
        column 0 = mean and column 1 = log-variance; "safety" is read as
        logits over dim 1.

        Returns:
            Dict with:
                - 'stress_mean': (batch,) ensemble mean prediction
                - 'stress_var': (batch,) total variance (epistemic + aleatoric)
                - 'stress_aleatoric': (batch,) mean of member variances
                - 'stress_epistemic': (batch,) population variance of member means
                - 'deflection_mean', 'deflection_var', 'deflection_aleatoric',
                  'deflection_epistemic': same as above for deflection
                - 'safety': (batch, 3) averaged softmax probabilities
                - 'member_outputs': list of individual member outputs (not a tensor)
        """
        member_outputs = [member(x) for member in self.members]

        result = {}
        for key in ["stress", "deflection"]:
            means = torch.stack([out[key][:, 0] for out in member_outputs])  # (M, batch)
            log_vars = torch.stack([out[key][:, 1] for out in member_outputs])
            # Clamp before exp so a diverging member cannot yield 0 or inf variance.
            vars_ = torch.exp(torch.clamp(log_vars, min=-10.0, max=10.0))  # (M, batch)

            ensemble_mean = means.mean(dim=0)  # (batch,)

            # Law of total variance for a uniform mixture:
            #   Var = E[Var_i] + Var[Mean_i]
            # Var[Mean_i] must be the population variance (divide by M). The
            # unbiased estimator would inflate it by M/(M-1) and is NaN for M == 1.
            aleatoric = vars_.mean(dim=0)
            epistemic = (means - ensemble_mean).pow(2).mean(dim=0)
            total_var = aleatoric + epistemic

            result[f"{key}_mean"] = ensemble_mean
            result[f"{key}_var"] = total_var
            result[f"{key}_aleatoric"] = aleatoric
            result[f"{key}_epistemic"] = epistemic

        # Safety: average the members' softmax probabilities (not their logits).
        safety_probs = torch.stack([
            torch.softmax(out["safety"], dim=1) for out in member_outputs
        ])
        result["safety"] = safety_probs.mean(dim=0)
        result["member_outputs"] = member_outputs

        return result

    def predict_with_uncertainty(
        self,
        x: torch.Tensor,
        confidence: float = 0.95,
    ) -> dict[str, torch.Tensor]:
        """Predict with Gaussian-approximation confidence intervals.

        Runs without gradient tracking and in eval mode, then restores the
        module's previous train/eval mode.

        Args:
            x: Input tensor.
            confidence: Confidence level for the two-sided interval, strictly
                between 0 and 1 (default 95%).

        Returns:
            Dict with ``{key}_mean``, ``{key}_lower``, ``{key}_upper`` and
            ``{key}_std`` for stress and deflection, plus ``safety_probs`` and
            ``safety_class`` (argmax over the averaged probabilities).

        Raises:
            ValueError: If ``confidence`` is not in the open interval (0, 1).
        """
        if not 0.0 < confidence < 1.0:
            raise ValueError(f"confidence must be in (0, 1), got {confidence}")
        # Two-sided standard-normal quantile, e.g. ~1.96 for 95%.
        z = NormalDist().inv_cdf(0.5 + confidence / 2)

        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                out = self.forward(x)
        finally:
            # Don't leak eval mode into a caller that was mid-training.
            self.train(was_training)

        result = {}
        for key in ["stress", "deflection"]:
            mean = out[f"{key}_mean"]
            std = torch.sqrt(out[f"{key}_var"])
            result[f"{key}_mean"] = mean
            result[f"{key}_lower"] = mean - z * std
            result[f"{key}_upper"] = mean + z * std
            result[f"{key}_std"] = std

        result["safety_probs"] = out["safety"]
        result["safety_class"] = out["safety"].argmax(dim=1)

        return result

    def save(self, directory: Path) -> None:
        """Save each member's state_dict as ``member_{i}.pt`` under ``directory``.

        ``directory`` may be a ``Path`` or a path string; it is created
        (with parents) if missing.
        """
        directory = Path(directory)
        directory.mkdir(parents=True, exist_ok=True)
        for i, member in enumerate(self.members):
            torch.save(member.state_dict(), directory / f"member_{i}.pt")

    @classmethod
    def load(cls, directory: Path, num_members: int = 5, **model_kwargs: object) -> "DeepEnsemble":
        """Load an ensemble from a directory of ``member_{i}.pt`` checkpoints.

        Args:
            directory: Directory (``Path`` or path string) written by ``save``.
            num_members: Number of members to construct and load.
            **model_kwargs: Forwarded to every ``PIResMLP`` constructor; must
                match the architecture used when saving.

        Returns:
            A new ensemble with weights loaded onto the CPU.
        """
        directory = Path(directory)
        ensemble = cls(num_members=num_members, **model_kwargs)
        for i, member in enumerate(ensemble.members):
            state = torch.load(
                directory / f"member_{i}.pt", map_location="cpu", weights_only=True
            )
            member.load_state_dict(state)
        return ensemble