Spaces:
Sleeping
Sleeping
| """Physics-informed composite loss function. | |
| L_total = L_regression + lambda_cls * L_classification + lambda_phys * L_physics | |
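The physics penalty encodes known physical relationships without needing
exact solutions, with the aim of improving generalization and extrapolation:
1. Monotonicity: stress increases with load magnitude (not enforced by the
   penalty implemented in this module).
2. Energy bound: deflection must be non-negative; because deflection is
   predicted in log10 space, this is enforced as a penalty on implausibly
   small predicted values.
3. Safety consistency: the safety-factor category derived from the stress
   regression must match the classification head.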
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class PhysicsInformedLoss(nn.Module):
    """Composite loss with physics-informed regularization.

    ``total = regression + cls_weight * classification + phys_weight * physics``

    Supports heteroscedastic regression (each regression head predicts a mean
    and a log-variance) via a Gaussian negative log-likelihood, which lets the
    model calibrate its own per-sample uncertainty. With
    ``heteroscedastic=False`` each regression head predicts only a mean and
    plain MSE is used instead.

    Args:
        classification_weight: Multiplier on the cross-entropy term.
        physics_weight: Multiplier on the physics penalty.
        heteroscedastic: If True, regression outputs are ``(batch, 2)`` tensors
            holding ``[mean, log_var]``; otherwise they hold only the mean
            (presumably ``(batch, 1)``; a trailing size-1 dim is squeezed).
        log_deflection_floor: Keyword-only. Predicted log10-deflection means
            below this value are penalized linearly by the energy-bound term.
            Defaults to -20.0, i.e. deflections under 1e-20.
    """

    # log10(2): a safety factor of 2 is the safe/marginal boundary.
    _LOG10_TWO = 0.3010299956639812
    # Predicted log-variance is clamped to this range before exp() so the
    # NLL can neither overflow nor divide by a vanishing variance.
    _LOG_VAR_MIN = -10.0
    _LOG_VAR_MAX = 10.0

    def __init__(
        self,
        classification_weight: float = 0.3,
        physics_weight: float = 0.1,
        heteroscedastic: bool = True,
        *,
        log_deflection_floor: float = -20.0,
    ) -> None:
        super().__init__()
        self.cls_weight = classification_weight
        self.phys_weight = physics_weight
        self.heteroscedastic = heteroscedastic
        self.log_deflection_floor = log_deflection_floor
        self.ce_loss = nn.CrossEntropyLoss()

    def _mean(self, pred: torch.Tensor) -> torch.Tensor:
        """Extract the predicted mean from a regression head's output.

        Heteroscedastic heads store the mean in column 0; mean-only heads have
        their trailing size-1 dimension squeezed away.
        """
        return pred[:, 0] if self.heteroscedastic else pred.squeeze(-1)

    def _regression_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Heteroscedastic Gaussian NLL, or plain MSE.

        For heteroscedastic heads ``pred`` is ``(batch, 2)`` = ``[mu, log_var]``
        and the per-sample loss is
        ``0.5 * (log_var + (y - mu)^2 / exp(log_var))``
        (the constant ``0.5 * log(2*pi)`` is omitted).

        Args:
            pred: Regression head output.
            target: Ground-truth values, one per sample. Any shape with the
                same number of elements as the predicted mean is accepted,
                e.g. ``(batch,)`` or ``(batch, 1)``.

        Returns:
            Scalar loss averaged over the batch.
        """
        mu = self._mean(pred)
        # Align target with mu: a (batch, 1) target would otherwise broadcast
        # against the (batch,) mean into a (batch, batch) residual matrix.
        target = target.reshape(mu.shape)
        if not self.heteroscedastic:
            return F.mse_loss(mu, target)
        log_var = pred[:, 1].clamp(min=self._LOG_VAR_MIN, max=self._LOG_VAR_MAX)
        nll = 0.5 * (log_var + (target - mu) ** 2 / log_var.exp())
        return nll.mean()

    def _physics_penalty(
        self,
        stress_pred: torch.Tensor,
        deflection_pred: torch.Tensor,
        safety_logits: torch.Tensor,
        targets: dict[str, torch.Tensor],
    ) -> torch.Tensor:
        """Physics-informed regularization penalties.

        1. Energy bound: deflection is predicted in log10 space, so every
           value maps to a positive deflection. Instead, predicted means below
           ``log_deflection_floor`` (unphysically small deflections) are
           penalized linearly.
        2. Safety consistency (only when ``targets`` has
           ``"log_yield_strength"``): the safety factor
           ``SF = 10^(log_yield - log_stress_pred)`` is bucketed into
           safe (SF >= 2), marginal (1 <= SF < 2) or failure (SF < 1), and the
           classification head is pulled towards that bucket via
           KL(derived || predicted). The bucketing uses hard comparisons, so
           no gradient reaches the stress head through this term; it trains
           only the classification head.

        Args:
            stress_pred: Stress regression output.
            deflection_pred: Deflection regression output.
            safety_logits: Classification logits with 3 columns. Assumed to be
                ordered [safe, marginal, failure] to match the derived
                distribution -- TODO confirm against the label encoding.
            targets: Target dict; only ``"log_yield_strength"`` is read here.

        Returns:
            Scalar penalty (sum of the active terms).
        """
        stress_mu = self._mean(stress_pred)
        defl_mu = self._mean(deflection_pred)

        # 1. Energy bound: linear penalty below the log-deflection floor.
        penalty = F.relu(self.log_deflection_floor - defl_mu).mean()

        # 2. Safety consistency between regression and classification heads.
        if "log_yield_strength" in targets:
            # log10(SF) = log10(yield) - log10(stress); the reshape guards
            # against a (batch, 1) target broadcasting against the mean.
            log_yield = targets["log_yield_strength"].reshape(stress_mu.shape)
            log_sf = log_yield - stress_mu
            derived_probs = torch.stack(
                [
                    log_sf >= self._LOG10_TWO,  # safe
                    (log_sf >= 0) & (log_sf < self._LOG10_TWO),  # marginal
                    log_sf < 0,  # failure
                ],
                dim=1,
            ).to(safety_logits.dtype)
            # log_softmax is the numerically stable form of softmax().log():
            # it does not underflow to -inf, so no clamp is needed and the
            # KL value and gradient stay correct for very confident logits.
            log_pred = F.log_softmax(safety_logits, dim=1)
            consistency = F.kl_div(log_pred, derived_probs, reduction="batchmean")
            penalty = penalty + consistency

        return penalty

    def forward(
        self,
        predictions: dict[str, torch.Tensor],
        targets: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """Compute the total loss together with its components.

        Args:
            predictions: Model outputs with keys ``'stress'``, ``'deflection'``
                (regression heads) and ``'safety'`` (classification logits).
            targets: Dict with ``'log_stress'``, ``'log_deflection'``,
                ``'safety_class'`` (targets for ``nn.CrossEntropyLoss``,
                presumably class indices) and, optionally,
                ``'log_yield_strength'`` (enables the safety-consistency
                penalty).

        Returns:
            Dict with ``'total'``, ``'regression'``, ``'classification'`` and
            ``'physics'`` scalar losses. Only ``'total'`` applies the weights.
        """
        stress_loss = self._regression_loss(predictions["stress"], targets["log_stress"])
        defl_loss = self._regression_loss(
            predictions["deflection"], targets["log_deflection"]
        )
        regression_loss = stress_loss + defl_loss

        cls_loss = self.ce_loss(predictions["safety"], targets["safety_class"])

        phys_loss = self._physics_penalty(
            predictions["stress"],
            predictions["deflection"],
            predictions["safety"],
            targets,
        )

        total = regression_loss + self.cls_weight * cls_loss + self.phys_weight * phys_loss
        return {
            "total": total,
            "regression": regression_loss,
            "classification": cls_loss,
            "physics": phys_loss,
        }