fea-surrogate / src /models /architecture.py
WolfDavid's picture
Upload folder using huggingface_hub
8e5ba9e verified
"""Physics-Informed Residual MLP (PI-ResMLP) for structural analysis surrogate.
Architecture:
    Input (N features) -> Linear(N, 256) -> LayerNorm -> SiLU -> Dropout
    -> ResidualBlock(256) x 4
    -> Linear(256, 128) -> LayerNorm -> SiLU
    -> OutputHeads:
        stress_head:     Linear(128, 2)  [mean, log_var]  (heteroscedastic)
        deflection_head: Linear(128, 2)  [mean, log_var]
        safety_head:     Linear(128, 3)  [safe, marginal, failure]
Why MLP over Transformer: Tabular regression on 15-20 numeric features does
not benefit from attention mechanisms. Using a transformer here would signal
cargo-cult thinking. (See: Grinsztajn et al., 2022)
Why residual connections: They prevent vanishing gradients in deeper networks
and allow the model to learn identity mappings where features already predict
the output well (e.g., load directly determines stress direction).
"""
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
    """Residual block wrapping a two-layer MLP branch with LayerNorm.

    The input runs through ``Linear -> LayerNorm -> SiLU -> Dropout ->
    Linear`` and the branch output is added back onto the unmodified input.
    Both linear layers map ``dim -> dim``, so the feature width is preserved.

    Args:
        dim: Width of the features entering and leaving the block.
        dropout: Drop probability handed to the inner ``nn.Dropout``.
    """

    def __init__(self, dim: int, dropout: float = 0.1) -> None:
        super().__init__()
        # The position of each layer determines its state-dict prefix
        # (block.0, block.1, block.4), so the order must stay as listed.
        branch_layers = [
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.SiLU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*branch_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the output of the transformation branch."""
        branch_out = self.block(x)
        return x + branch_out
class PIResMLP(nn.Module):
    """Residual MLP with a shared trunk and three task-specific heads.

    Data flow:
        * ``stem``: Linear -> LayerNorm -> SiLU -> Dropout, lifting the
          input from ``input_dim`` to ``hidden_dim``.
        * ``backbone``: ``num_blocks`` residual blocks at ``hidden_dim``.
        * ``neck``: Linear -> LayerNorm -> SiLU, narrowing to
          ``hidden_dim // 2``.
        * ``stress_head`` / ``deflection_head`` / ``safety_head``: one linear
          layer each, all reading the same neck output.

    NOTE(review): the design notes say stress and deflection are predicted in
    log-space (quantities span Pa to GPa) and that the heteroscedastic
    [mean, log_var] outputs are meant for uncertainty estimation with deep
    ensembles. Nothing in this module applies a log transform or enforces that
    interpretation, so it presumably lives in the data pipeline and loss;
    confirm against the training code.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of the stem output and of every residual block.
        num_blocks: How many residual blocks the backbone stacks.
        dropout: Drop probability for the stem and the residual blocks.
        num_classes: Number of logits produced by the safety head.
        heteroscedastic: When true, each regression head emits two values
            per sample instead of one.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int = 256,
        num_blocks: int = 4,
        dropout: float = 0.1,
        num_classes: int = 3,
        heteroscedastic: bool = True,
    ) -> None:
        super().__init__()
        self.heteroscedastic = heteroscedastic

        neck_dim = hidden_dim // 2
        # Two outputs per regression head when heteroscedastic, otherwise one.
        regression_width = 2 if heteroscedastic else 1

        # Submodules are built in forward order; this keeps parameter
        # registration and random initialisation order stable.

        # Stem: lift the raw features to the hidden width.
        stem_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.SiLU(),
            nn.Dropout(dropout),
        ]
        self.stem = nn.Sequential(*stem_layers)

        # Backbone: a stack of residual blocks at constant width.
        self.backbone = nn.Sequential(
            *(ResidualBlock(hidden_dim, dropout) for _ in range(num_blocks))
        )

        # Neck: narrow the representation before the heads.
        neck_layers = [
            nn.Linear(hidden_dim, neck_dim),
            nn.LayerNorm(neck_dim),
            nn.SiLU(),
        ]
        self.neck = nn.Sequential(*neck_layers)

        # Heads: two regression outputs and one classification output.
        self.stress_head = nn.Linear(neck_dim, regression_width)
        self.deflection_head = nn.Linear(neck_dim, regression_width)
        self.safety_head = nn.Linear(neck_dim, num_classes)

    def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]:
        """Run the trunk once and evaluate every head on its output.

        Args:
            x: Input tensor of shape (batch, input_dim).

        Returns:
            Dictionary with keys:
                - 'stress': (batch, 2) when heteroscedastic, else (batch, 1);
                  documented layout is [mean, log_var] or [mean].
                - 'deflection': same shape convention as 'stress'.
                - 'safety': (batch, num_classes) raw logits; with the default
                  of three classes the documented order is
                  [safe, marginal, failure].
        """
        features = self.neck(self.backbone(self.stem(x)))
        stress = self.stress_head(features)
        deflection = self.deflection_head(features)
        safety = self.safety_head(features)
        return {
            "stress": stress,
            "deflection": deflection,
            "safety": safety,
        }