# fea-surrogate / src / models / physics_loss.py
# Uploaded by WolfDavid using huggingface_hub
# Commit 8e5ba9e (verified)
"""Physics-informed composite loss function.
L_total = L_regression + lambda_cls * L_classification + lambda_phys * L_physics
The physics penalty encodes known physical relationships without needing
exact solutions, improving generalization and extrapolation:
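1. Monotonicity: stress increases with load magnitude (listed as a design
   goal; PhysicsInformedLoss in this file does not add a penalty term for it)
2. Energy bound: deflection must be non-negative (deflection is predicted in
   log10 space, so only predicted values below -20 are penalized)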
3. Safety consistency: regression-derived SF category must match classification head
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PhysicsInformedLoss(nn.Module):
    """Composite loss with physics-informed regularization.

    The total loss is::

        total = regression
                + classification_weight * classification
                + physics_weight * physics

    Supports heteroscedastic regression (predicting mean + log_variance)
    via negative log-likelihood, which naturally calibrates uncertainty.

    The physics term implemented here consists of an energy bound on the
    predicted deflection and a consistency term between the regression-derived
    safety category and the classification head. No monotonicity penalty is
    computed by this class.
    """

    # log10(2): a safety factor >= 2 is treated as "safe".
    _LOG10_SAFE_FACTOR = 0.30103
    # Predicted log10(deflection) below this value is penalized.
    _MIN_LOG10_DEFLECTION = -20.0
    # Lower bound applied to log-probabilities in the consistency term so a
    # single sample contributes at most 100 to the penalty.
    _MIN_LOG_PROB = -100.0

    def __init__(
        self,
        classification_weight: float = 0.3,
        physics_weight: float = 0.1,
        heteroscedastic: bool = True,
    ) -> None:
        """Configure the loss weights and the regression mode.

        Args:
            classification_weight: Multiplier applied to the cross-entropy
                loss of the safety head.
            physics_weight: Multiplier applied to the physics penalty.
            heteroscedastic: If True, regression heads are read as
                ``[mean, log_var]`` pairs and trained with Gaussian NLL;
                otherwise they are trained with plain MSE.
        """
        super().__init__()
        self.cls_weight = classification_weight
        self.phys_weight = physics_weight
        self.heteroscedastic = heteroscedastic
        self.ce_loss = nn.CrossEntropyLoss()

    @staticmethod
    def _match_shape(values: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Reshape ``values`` to ``reference``'s shape when only the layout differs.

        A ``(batch, 1)`` tensor combined with a ``(batch,)`` tensor would
        otherwise broadcast to ``(batch, batch)`` and silently yield a wrong
        loss. Tensors whose element counts differ are returned untouched so
        ordinary broadcasting (e.g. a scalar target) keeps working.
        """
        if values.shape != reference.shape and values.numel() == reference.numel():
            return values.reshape(reference.shape)
        return values

    def _predicted_mean(self, pred: torch.Tensor) -> torch.Tensor:
        """Return the predicted mean of a regression head output."""
        if self.heteroscedastic:
            return pred[:, 0]
        return pred.squeeze(-1)

    def _regression_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Heteroscedastic NLL or plain MSE.

        For heteroscedastic: pred is (batch, 2) with [mean, log_var].
        NLL = 0.5 * [log(var) + (y - mu)^2 / var]

        Args:
            pred: Regression head output.
            target: Ground-truth values; a trailing singleton dimension is
                tolerated and aligned with the predicted mean.

        Returns:
            Scalar loss averaged over the batch.
        """
        mu = self._predicted_mean(pred)
        target = self._match_shape(target, mu)

        if not self.heteroscedastic:
            return F.mse_loss(mu, target)

        # Clamp log_var for numerical stability (exp stays in a safe range).
        log_var = torch.clamp(pred[:, 1], min=-10.0, max=10.0)
        var = torch.exp(log_var)
        nll = 0.5 * (log_var + (target - mu) ** 2 / var)
        return nll.mean()

    def _physics_penalty(
        self,
        stress_pred: torch.Tensor,
        deflection_pred: torch.Tensor,
        safety_logits: torch.Tensor,
        targets: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Physics-informed regularization penalties.

        1. Energy bound: the predicted log10 deflection mean is penalized
           when it drops below -20 (unphysically small deflection).
        2. Safety consistency: the category derived from the regression
           output should match the classification head prediction. Only
           computed when ``targets`` contains ``"log_yield_strength"``.

        Args:
            stress_pred: Stress head output.
            deflection_pred: Deflection head output.
            safety_logits: Classification logits with three columns, read
                here as [safe, marginal, failure].
            targets: Target dict; ``"log_yield_strength"`` is optional.

        Returns:
            Scalar penalty (unweighted).
        """
        penalty = torch.tensor(0.0, device=stress_pred.device)

        stress_mu = self._predicted_mean(stress_pred)
        defl_mu = self._predicted_mean(deflection_pred)

        # 1. Energy bound. In log-space any real value is a positive
        # deflection, so only extreme negatives are penalized.
        energy_penalty = F.relu(-defl_mu + self._MIN_LOG10_DEFLECTION).mean()
        penalty = penalty + energy_penalty

        # 2. Safety consistency: safety_factor = 10^(log_yield - log_stress).
        if "log_yield_strength" in targets:
            log_yield = self._match_shape(targets["log_yield_strength"], stress_mu)
            log_sf = log_yield - stress_mu

            # SF >= 2 -> safe, 1 <= SF < 2 -> marginal, SF < 1 -> failure.
            # The comparisons are non-differentiable, so this term only
            # pulls the classification head toward the derived class.
            safe_cut = self._LOG10_SAFE_FACTOR
            derived_safe = (log_sf >= safe_cut).float()
            derived_marginal = ((log_sf >= 0) & (log_sf < safe_cut)).float()
            derived_failure = (log_sf < 0).float()
            derived_probs = torch.stack(
                [derived_safe, derived_marginal, derived_failure], dim=1
            )

            # log_softmax instead of softmax().log(): when a probability
            # underflows to 0 the latter yields -inf, and the subsequent clamp
            # produces NaN gradients (0 * inf) in the backward pass.
            log_pred = F.log_softmax(safety_logits, dim=1).clamp(min=self._MIN_LOG_PROB)
            consistency = F.kl_div(log_pred, derived_probs, reduction="batchmean")
            penalty = penalty + consistency

        return penalty

    def forward(
        self,
        predictions: dict[str, torch.Tensor],
        targets: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Compute total loss with breakdown.

        Args:
            predictions: Model output dict with 'stress', 'deflection', 'safety' keys.
            targets: Dict with 'log_stress', 'log_deflection', 'safety_class',
                and optionally 'log_yield_strength'.

        Returns:
            Dict with 'total', 'regression', 'classification', 'physics' losses.
            'total' is the weighted sum; the other entries are unweighted.
        """
        # Regression losses (one per head, summed).
        stress_loss = self._regression_loss(predictions["stress"], targets["log_stress"])
        defl_loss = self._regression_loss(predictions["deflection"], targets["log_deflection"])
        regression_loss = stress_loss + defl_loss

        # Classification loss on the safety head.
        cls_loss = self.ce_loss(predictions["safety"], targets["safety_class"])

        # Physics penalty.
        phys_loss = self._physics_penalty(
            predictions["stress"],
            predictions["deflection"],
            predictions["safety"],
            targets,
        )

        total = regression_loss + self.cls_weight * cls_loss + self.phys_weight * phys_loss
        return {
            "total": total,
            "regression": regression_loss,
            "classification": cls_loss,
            "physics": phys_loss,
        }