| """ |
| Physics-Informed Loss for LiquidFlow. |
| |
| Combines: |
| 1. Rectified Flow Matching loss: ||v_θ(x_t, t) - (x_1 - x_0)||² |
| 2. Physics residual: smoothness + continuity constraints on generated images |
| 3. ConFIG gradient conflict-free update (from TUM-PBS, 2024) |
| 4. Adaptive loss weighting (from PINN gradient pathology research, Wang et al. 2020) |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class PhysicsInformedFlowLoss(nn.Module):
    """
    Multi-objective loss for physics-informed flow matching.

    L = L_flow + w(step) * (λ_smooth * L_smooth + λ_tv * L_tv)

    Where:
    - L_flow: Rectified flow matching MSE against the straight-line
      velocity target x1 - x0
    - L_smooth: mean squared 5-point discrete Laplacian of the predicted
      clean sample (heat equation steady state)
    - L_tv: anisotropic total variation for edge preservation
    - w(step): linear warmup of the physics terms from 0 to 1 over
      ``warmup_steps`` training steps

    Args:
        lambda_smooth: weight of the Laplacian smoothness term.
        lambda_tv: weight of the total-variation term.
        lambda_continuity: stored on the module; not used by ``forward``.
        use_adaptive_weights: stored on the module; not used by ``forward``.
        warmup_steps: number of steps over which the physics terms ramp in.
            A value <= 0 disables the warmup (full weight from step 0).
    """

    def __init__(self, lambda_smooth=0.01, lambda_tv=0.001,
                 lambda_continuity=0.005, use_adaptive_weights=True,
                 warmup_steps=500):
        super().__init__()
        self.lambda_smooth = lambda_smooth
        self.lambda_tv = lambda_tv
        self.lambda_continuity = lambda_continuity
        self.use_adaptive_weights = use_adaptive_weights
        self.warmup_steps = warmup_steps
        # Buffers are part of the module's state_dict(); they are kept so
        # existing checkpoints still load, although forward() never reads them.
        self.register_buffer('flow_grad_norm', torch.tensor(1.0))
        self.register_buffer('physics_grad_norm', torch.tensor(1.0))
        self.ema_decay = 0.99

    def flow_matching_loss(self, v_pred, x0, x1, t):
        """Return the MSE between ``v_pred`` and the target velocity ``x1 - x0``.

        ``t`` is accepted for interface symmetry and is not used.
        """
        target = x1 - x0
        return F.mse_loss(v_pred, target)

    def smoothness_loss(self, x_pred):
        """Return the mean squared 5-point discrete Laplacian of ``x_pred``.

        Dims 2 and 3 are treated as height and width (4-D input; presumably
        (batch, channel, H, W) — confirm against callers). The stencil is
        evaluated only at interior pixels, so all four neighbours exist and
        are centred on the same pixel. Returns 0 when there is no interior
        pixel (H < 3 or W < 3), instead of the NaN that ``mean()`` of an
        empty tensor would produce.
        """
        if x_pred.shape[2] < 3 or x_pred.shape[3] < 3:
            return x_pred.new_zeros(())
        center = x_pred[:, :, 1:-1, 1:-1]
        laplacian = (x_pred[:, :, :-2, 1:-1] + x_pred[:, :, 2:, 1:-1]
                     + x_pred[:, :, 1:-1, :-2] + x_pred[:, :, 1:-1, 2:]
                     - 4 * center)
        return (laplacian ** 2).mean()

    def total_variation_loss(self, x_pred):
        """Return the anisotropic total variation of ``x_pred``.

        This is the mean absolute difference between vertical neighbours plus
        the mean absolute difference between horizontal neighbours (dims 2
        and 3). A direction with no neighbour pairs (size-1 dim) contributes
        0 rather than NaN.
        """
        diff_h = torch.abs(x_pred[:, :, 1:, :] - x_pred[:, :, :-1, :])
        diff_w = torch.abs(x_pred[:, :, :, 1:] - x_pred[:, :, :, :-1])
        loss = x_pred.new_zeros(())
        if diff_h.numel():
            loss = loss + diff_h.mean()
        if diff_w.numel():
            loss = loss + diff_w.mean()
        return loss

    def _physics_weight(self, step):
        """Return the linear warmup factor for ``step``, clamped to [0, 1]."""
        if self.warmup_steps <= 0:
            return 1.0
        return min(1.0, max(0.0, step / self.warmup_steps))

    def forward(self, v_pred, x0, x1, t, x_pred_clean=None, step=0):
        """Compute the total loss and its components.

        Args:
            v_pred: predicted velocity.
            x0, x1: endpoints of the interpolation path
                ``x_t = t * x1 + (1 - t) * x0``.
            t: per-sample times, reshaped to (-1, 1, 1, 1) when the clean
                sample has to be reconstructed.
            x_pred_clean: optional precomputed prediction of the t=1 endpoint.
                If omitted, it is reconstructed from ``v_pred``.
            step: training step driving the physics-term warmup.

        Returns:
            ``(total, losses)``, where ``losses`` maps 'total', 'flow',
            'smooth', 'tv' and 'physics_weight' to tensors.
        """
        loss_flow = self.flow_matching_loss(v_pred, x0, x1, t)

        if x_pred_clean is None:
            # Along the straight path the velocity is x1 - x0, so moving from
            # x_t for the remaining time (1 - t) with the predicted velocity
            # estimates the t=1 endpoint.
            t_expand = t.view(-1, 1, 1, 1)
            x_t = t_expand * x1 + (1 - t_expand) * x0
            x_pred_clean = x_t + v_pred * (1 - t_expand)

        physics_weight = self._physics_weight(step)

        loss_smooth = self.smoothness_loss(x_pred_clean)
        loss_tv = self.total_variation_loss(x_pred_clean)

        total = (loss_flow
                 + physics_weight * self.lambda_smooth * loss_smooth
                 + physics_weight * self.lambda_tv * loss_tv)

        losses = {
            'total': total, 'flow': loss_flow,
            'smooth': loss_smooth, 'tv': loss_tv,
            'physics_weight': torch.tensor(physics_weight),
        }
        return total, losses
|
|
|
|
class EMAModel:
    """Exponential Moving Average of model parameters.

    Keeps a "shadow" copy of every parameter whose ``requires_grad`` is set
    at construction time and blends it towards the live parameters on each
    ``update`` call.

    Args:
        model: module whose ``named_parameters()`` are tracked.
        decay: upper bound on the per-step EMA decay.
        warmup_steps: stored on the instance but not used by this class. The
            early-training ramp comes from the step counter in ``update``.
    """

    def __init__(self, model, decay=0.9999, warmup_steps=1000):
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.step = 0
        self.shadow = {}
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, model):
        """Increment the step counter and blend live parameters into the shadow.

        The effective decay is ``min(decay, (1 + step) / (10 + step))``, so
        early updates follow the live weights closely. Each shadow entry is
        rebound to a new tensor; existing shadow tensors are never mutated.
        """
        self.step += 1
        decay = min(self.decay, (1 + self.step) / (10 + self.step))
        for name, param in model.named_parameters():
            if param.requires_grad:
                # Shadows restored by load_state_dict may sit on a different
                # device than the model (e.g. a CPU-mapped checkpoint).
                shadow = self.shadow[name].to(param.device)
                self.shadow[name] = decay * shadow + (1 - decay) * param.data

    def apply_shadow(self, model):
        """Copy the EMA weights into ``model``, backing up the live weights.

        If a backup already exists (``apply_shadow`` called again without
        ``restore``), the existing backup is kept. Otherwise it would be
        overwritten with EMA weights and the live weights would be lost.
        """
        take_backup = not self.backup
        for name, param in model.named_parameters():
            if param.requires_grad:
                if take_backup:
                    self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self, model):
        """Copy the backed-up live weights back into ``model`` and clear the backup."""
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

    def state_dict(self):
        """Return a snapshot of the shadow weights and step counter.

        The shadow mapping is shallow-copied, so later ``update`` calls (which
        rebind entries) do not change a previously returned snapshot.
        """
        return {'shadow': dict(self.shadow), 'step': self.step}

    def load_state_dict(self, state_dict):
        """Load shadow weights and step counter produced by ``state_dict``.

        The caller's mapping is copied, so subsequent updates do not modify it.
        """
        self.shadow = dict(state_dict['shadow'])
        self.step = state_dict['step']