"""Physics-informed composite loss function.

    L_total = L_regression + lambda_cls * L_classification + lambda_phys * L_physics

The physics penalty encodes known physical relationships without needing exact
solutions, with the aim of improving generalization and extrapolation. The
penalties implemented in this module are:

1. Deflection lower bound: deflection is predicted in log10-space, so the
   predicted mean is penalized only when it drops below a floor (default
   -20), i.e. for unphysically small deflections.
2. Safety consistency: the safety-factor category derived from the stress
   regression head must match the classification head.

A monotonicity constraint (stress increasing with load magnitude) is not
enforced by this module.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PhysicsInformedLoss(nn.Module):
    """Composite loss with physics-informed regularization.

    Supports heteroscedastic regression (predicting mean + log-variance) via a
    Gaussian negative log-likelihood, which naturally calibrates uncertainty.

    Args:
        classification_weight: Weight ``lambda_cls`` on the cross-entropy term.
        physics_weight: Weight ``lambda_phys`` on the physics penalty.
        heteroscedastic: If True, each regression head's output is indexed as
            ``[:, 0]`` = mean and ``[:, 1]`` = log-variance and trained with
            Gaussian NLL; otherwise the output is squeezed on its last dim and
            trained with MSE.
        min_log_deflection: Floor on the predicted log10-deflection mean below
            which the deflection-bound penalty becomes active.
        safe_factor: Safety factor at or above which a sample counts as
            "safe"; ``1 <= SF < safe_factor`` is "marginal" and ``SF < 1`` is
            "failure". Must be greater than 1.

    Raises:
        ValueError: If ``safe_factor <= 1``.
    """

    # Clamp range for the predicted log-variance (numerical stability).
    _LOG_VAR_MIN = -10.0
    _LOG_VAR_MAX = 10.0

    def __init__(
        self,
        classification_weight: float = 0.3,
        physics_weight: float = 0.1,
        heteroscedastic: bool = True,
        *,
        min_log_deflection: float = -20.0,
        safe_factor: float = 2.0,
    ) -> None:
        super().__init__()
        if safe_factor <= 1.0:
            raise ValueError(f"safe_factor must be > 1, got {safe_factor}")
        self.cls_weight = classification_weight
        self.phys_weight = physics_weight
        self.heteroscedastic = heteroscedastic
        self.min_log_deflection = min_log_deflection
        # Safety factors are compared in log10-space (log_sf = log_yield - log_stress).
        self.log_safe_factor = math.log10(safe_factor)
        self.ce_loss = nn.CrossEntropyLoss()

    def _mean(self, pred: torch.Tensor) -> torch.Tensor:
        """Return the predicted mean of a regression head as a 1-D tensor."""
        return pred[:, 0] if self.heteroscedastic else pred.squeeze(-1)

    @staticmethod
    def _align(target: torch.Tensor, ref: torch.Tensor) -> torch.Tensor:
        """Drop a trailing singleton dim from ``target`` so it lines up with ``ref``.

        Without this, a ``(batch, 1)`` target combined with a ``(batch,)``
        prediction would silently broadcast to ``(batch, batch)``.
        """
        if target.dim() == ref.dim() + 1 and target.shape[-1] == 1:
            return target.squeeze(-1)
        return target

    def _regression_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Heteroscedastic Gaussian NLL or plain MSE, averaged over the batch.

        For heteroscedastic mode, ``pred[:, 0]`` is the mean ``mu`` and
        ``pred[:, 1]`` the log-variance, and the per-sample loss is
        ``0.5 * [log(var) + (y - mu)^2 / var]`` (constant term dropped).

        Args:
            pred: Regression head output.
            target: Regression target, ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Scalar loss tensor.
        """
        mu = self._mean(pred)
        target = self._align(target, mu)
        if not self.heteroscedastic:
            return F.mse_loss(mu, target)

        log_var = pred[:, 1].clamp(min=self._LOG_VAR_MIN, max=self._LOG_VAR_MAX)
        # Multiplying by exp(-log_var) avoids an explicit division by the variance.
        nll = 0.5 * (log_var + (target - mu) ** 2 * torch.exp(-log_var))
        return nll.mean()

    def _physics_penalty(
        self,
        stress_pred: torch.Tensor,
        deflection_pred: torch.Tensor,
        safety_logits: torch.Tensor,
        targets: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Physics-informed regularization penalties.

        1. Deflection lower bound: the predicted log10-deflection mean is
           penalized linearly once it falls below ``min_log_deflection``.
        2. Safety consistency (only if ``targets`` has ``log_yield_strength``):
           the category implied by ``log_yield_strength - predicted log_stress``
           is compared with the classification head via KL divergence.

        Returns:
            Scalar penalty tensor.
        """
        stress_mu = self._mean(stress_pred)
        defl_mu = self._mean(deflection_pred)

        # 1. Deflection lower bound. In log-space any real value is valid,
        # so only extreme negatives (unphysically small deflections) are penalized.
        penalty = F.relu(self.min_log_deflection - defl_mu).mean()

        # 2. Safety consistency: derive a hard category from the regression head.
        log_yield = targets.get("log_yield_strength")
        if log_yield is not None:
            log_sf = self._align(log_yield, stress_mu) - stress_mu
            # Order [safe, marginal, failure]; assumed to match the encoding of
            # targets["safety_class"] -- NOTE(review): confirm against the dataset.
            derived_safe = log_sf >= self.log_safe_factor
            derived_marginal = (log_sf >= 0) & (log_sf < self.log_safe_factor)
            derived_failure = log_sf < 0

            # log_softmax stays finite even when a class probability underflows;
            # softmax().log() yields -inf there and NaN gradients on backward.
            log_probs = F.log_softmax(safety_logits, dim=1)
            derived_probs = torch.stack(
                [derived_safe, derived_marginal, derived_failure], dim=1
            ).to(log_probs.dtype)
            # With one-hot rows this KL equals cross-entropy against the derived class.
            consistency = F.kl_div(log_probs, derived_probs, reduction="batchmean")
            penalty = penalty + consistency

        return penalty

    def forward(
        self,
        predictions: dict[str, torch.Tensor],
        targets: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Compute total loss with breakdown.

        Args:
            predictions: Model output dict with 'stress', 'deflection', 'safety' keys.
            targets: Dict with 'log_stress', 'log_deflection', 'safety_class',
                and optionally 'log_yield_strength'.

        Returns:
            Dict with 'total', 'regression', 'classification', 'physics' losses.
        """
        # Regression losses
        stress_loss = self._regression_loss(predictions["stress"], targets["log_stress"])
        defl_loss = self._regression_loss(predictions["deflection"], targets["log_deflection"])
        regression_loss = stress_loss + defl_loss

        # Classification loss
        cls_loss = self.ce_loss(predictions["safety"], targets["safety_class"])

        # Physics penalty
        phys_loss = self._physics_penalty(
            predictions["stress"],
            predictions["deflection"],
            predictions["safety"],
            targets,
        )

        total = regression_loss + self.cls_weight * cls_loss + self.phys_weight * phys_loss

        return {
            "total": total,
            "regression": regression_loss,
            "classification": cls_loss,
            "physics": phys_loss,
        }