"""Physics-Informed Residual MLP (PI-ResMLP) for structural analysis surrogates.

Architecture (default hyper-parameters shown)::

    Input (N features)
      -> Linear(N, 256) -> LayerNorm -> SiLU -> Dropout        # stem
      -> ResidualBlock(256) x 4                                # backbone
      -> Linear(256, 128) -> LayerNorm -> SiLU                 # neck
      -> Output heads:
           stress_head:     Linear(128, 2)  [mean, log_var]  (heteroscedastic)
           deflection_head: Linear(128, 2)  [mean, log_var]
           safety_head:     Linear(128, 3)  logits for [safe, marginal, failure]

Why an MLP rather than a Transformer:
    The inputs are 15-20 numeric (tabular) features with no sequence or token
    structure, so attention offers little benefit over a well-regularised MLP
    while adding cost and complexity (cf. Grinsztajn et al., 2022, on the
    limited gains of deep architectures on tabular data).

Why residual connections:
    They ease gradient flow through the deeper backbone and let each block
    stay close to an identity mapping, so the network only has to learn
    corrections where the incoming representation is not already sufficient
    (e.g. load already largely determines the direction of stress).

Note: this module defines only the network. Nothing here enforces physical
constraints directly; any physics-informed terms must come from the input
features or the training loss.
"""

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Residual MLP block with an inner LayerNorm.

    Computes ``x + F(x)`` where::

        F = Linear -> LayerNorm -> SiLU -> Dropout -> Linear

    This is *not* a pre-activation block: normalisation is applied after the
    first Linear rather than to the block input, and nothing (no activation,
    no normalisation) follows the residual addition, so the residual stream
    itself is never normalised inside the block.

    Args:
        dim: Feature width; input and output share this last dimension.
        dropout: Dropout probability applied after the activation.

    Raises:
        ValueError: If ``dim`` < 1.
    """

    def __init__(self, dim: int, dropout: float = 0.1) -> None:
        super().__init__()
        if dim < 1:
            raise ValueError(f"dim must be >= 1, got {dim}")
        # The attribute name and layer order are kept fixed so that
        # state_dict keys (block.0.*, block.1.*, block.4.*) stay loadable.
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x + block(x)``; the output has the same shape as ``x``."""
        return x + self.block(x)


class PIResMLP(nn.Module):
    """Residual MLP with multi-task regression and classification heads.

    Stress and deflection span many orders of magnitude (Pa to GPa), so their
    targets are meant to be supplied in log-space; the network itself applies
    no transform to its outputs. With ``heteroscedastic=True`` each
    regression head emits ``[mean, log_var]``, intended for a Gaussian
    negative log-likelihood loss and uncertainty estimation (e.g. across a
    deep ensemble); training/ensembling code is not part of this module.
    ``log_var`` is unconstrained here, so clamp or exponentiate it downstream.

    Args:
        input_dim: Number of input features (N).
        hidden_dim: Width of the stem and residual backbone. The neck
            compresses to ``hidden_dim // 2``, so this must be >= 2.
        num_blocks: Number of residual blocks (0 connects stem to neck).
        dropout: Dropout probability used in the stem and residual blocks.
        num_classes: Number of safety classes (default 3: safe / marginal /
            failure).
        heteroscedastic: If True, regression heads output ``[mean, log_var]``;
            otherwise only ``[mean]``.

    Raises:
        ValueError: If ``input_dim`` < 1, ``hidden_dim`` < 2,
            ``num_blocks`` < 0, or ``num_classes`` < 1.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        num_blocks: int = 4,
        dropout: float = 0.1,
        num_classes: int = 3,
        heteroscedastic: bool = True,
    ) -> None:
        super().__init__()
        if input_dim < 1:
            raise ValueError(f"input_dim must be >= 1, got {input_dim}")
        if hidden_dim < 2:
            # hidden_dim // 2 would be 0, silently producing an empty neck.
            raise ValueError(f"hidden_dim must be >= 2, got {hidden_dim}")
        if num_blocks < 0:
            raise ValueError(f"num_blocks must be >= 0, got {num_blocks}")
        if num_classes < 1:
            raise ValueError(f"num_classes must be >= 1, got {num_classes}")

        self.heteroscedastic = heteroscedastic
        neck_dim = hidden_dim // 2
        # Two outputs ([mean, log_var]) per regression head, else one ([mean]).
        out_per_head = 2 if heteroscedastic else 1

        # Submodules are created in a fixed order (stem, backbone, neck,
        # heads) so parameter initialisation under a given seed and the
        # state_dict layout stay stable.

        # Stem: project input features to the hidden width.
        self.stem = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
        )

        # Residual backbone at constant width.
        self.backbone = nn.Sequential(
            *(ResidualBlock(hidden_dim, dropout) for _ in range(num_blocks))
        )

        # Neck: compress to the shared representation fed to every head.
        self.neck = nn.Sequential(
            nn.Linear(hidden_dim, neck_dim),
            nn.LayerNorm(neck_dim),
            nn.SiLU(),
        )

        # Output heads.
        self.stress_head = nn.Linear(neck_dim, out_per_head)
        self.deflection_head = nn.Linear(neck_dim, out_per_head)
        self.safety_head = nn.Linear(neck_dim, num_classes)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the network.

        Args:
            x: Input tensor of shape ``(batch, input_dim)``.

        Returns:
            Dictionary with keys:
                - ``'stress'``: ``(batch, 2)`` ``[mean, log_var]`` if
                  heteroscedastic, else ``(batch, 1)`` ``[mean]``.
                - ``'deflection'``: same layout as ``'stress'``.
                - ``'safety'``: ``(batch, num_classes)`` unnormalised logits.
        """
        features = self.neck(self.backbone(self.stem(x)))
        return {
            "stress": self.stress_head(features),
            "deflection": self.deflection_head(features),
            "safety": self.safety_head(features),
        }