"""
Physics-Informed Loss for LiquidFlow.

Combines:
1. Rectified Flow Matching loss: ||v_θ(x_t, t) - (x_1 - x_0)||²
2. Physics residual: Laplacian smoothness + total-variation terms on the
   one-step estimate of the generated image
3. ConFIG gradient conflict-free update (from TUM-PBS, 2024)
4. Adaptive loss weighting (from PINN gradient pathology research, Wang et al. 2020)

NOTE(review): items 3 and 4 (and a continuity constraint) are not implemented
in this module. The ``lambda_continuity`` / ``use_adaptive_weights`` arguments
and the ``flow_grad_norm`` / ``physics_grad_norm`` buffers are stored but never
read by ``PhysicsInformedFlowLoss.forward``; they are kept so existing callers
and checkpoints (buffers are part of ``state_dict``) keep working.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class PhysicsInformedFlowLoss(nn.Module):
    """
    Multi-objective loss for physics-informed flow matching.

        L = L_flow + w(step) * λ_smooth * L_smooth + w(step) * λ_tv * L_tv

    Where:
    - L_flow: Rectified flow matching MSE
    - L_smooth: mean squared 5-point discrete Laplacian (heat equation steady state)
    - L_tv: anisotropic total variation (mean absolute finite differences)
    - w(step): linear warmup from 0 to 1 over ``physics_warmup_steps`` steps

    Args:
        lambda_smooth: weight of the Laplacian smoothness term.
        lambda_tv: weight of the total-variation term.
        lambda_continuity: stored only; not used by ``forward``.
        use_adaptive_weights: stored only; not used by ``forward``.
        physics_warmup_steps: steps over which the physics terms ramp linearly
            from 0 to full weight. Defaults to 500 (the previously hard-coded
            value); a value <= 0 disables the warmup.
    """

    def __init__(self, lambda_smooth=0.01, lambda_tv=0.001,
                 lambda_continuity=0.005, use_adaptive_weights=True,
                 physics_warmup_steps=500):
        super().__init__()
        self.lambda_smooth = lambda_smooth
        self.lambda_tv = lambda_tv
        self.lambda_continuity = lambda_continuity
        self.use_adaptive_weights = use_adaptive_weights
        self.physics_warmup_steps = physics_warmup_steps
        # Placeholders for gradient-norm based adaptive weighting. Registered as
        # buffers so they are saved/restored with the module's state_dict.
        self.register_buffer('flow_grad_norm', torch.tensor(1.0))
        self.register_buffer('physics_grad_norm', torch.tensor(1.0))
        self.ema_decay = 0.99

    def _physics_weight(self, step):
        """Return the linear warmup factor in [0, 1] for the physics terms."""
        # getattr keeps modules pickled before this attribute existed working.
        warmup = getattr(self, 'physics_warmup_steps', 500)
        if warmup <= 0:
            return 1.0
        return min(1.0, max(0.0, step / warmup))

    def flow_matching_loss(self, v_pred, x0, x1, t):
        """MSE between the predicted velocity and the straight-line target x1 - x0.

        ``t`` is accepted for interface compatibility but not used.
        """
        target = x1 - x0
        return F.mse_loss(v_pred, target)

    def smoothness_loss(self, x_pred):
        """Mean squared 5-point discrete Laplacian over interior pixels.

        Dims 2 and 3 are treated as the spatial axes, so ``x_pred`` is expected
        to be 4-D (presumably (B, C, H, W) -- confirm with callers). Returns a
        zero scalar when either spatial size is < 3, because no interior pixel
        exists (the mean of an empty tensor would otherwise be NaN).
        """
        if x_pred.shape[2] < 3 or x_pred.shape[3] < 3:
            return x_pred.new_zeros(())
        center = x_pred[:, :, 1:-1, 1:-1]
        # Both second differences are centred on the SAME interior pixel; the
        # previous crop-to-common-shape approach mixed terms from different
        # pixels and did not compute a Laplacian.
        lap_h = x_pred[:, :, 2:, 1:-1] - 2 * center + x_pred[:, :, :-2, 1:-1]
        lap_w = x_pred[:, :, 1:-1, 2:] - 2 * center + x_pred[:, :, 1:-1, :-2]
        laplacian = lap_h + lap_w
        return (laplacian ** 2).mean()

    def total_variation_loss(self, x_pred):
        """Anisotropic total variation: mean |Δ| along dim 2 plus mean |Δ| along dim 3.

        An axis of size 1 has no finite differences and contributes 0 instead
        of producing NaN.
        """
        loss = x_pred.new_zeros(())
        if x_pred.shape[2] > 1:
            diff_h = torch.abs(x_pred[:, :, 1:, :] - x_pred[:, :, :-1, :])
            loss = loss + diff_h.mean()
        if x_pred.shape[3] > 1:
            diff_w = torch.abs(x_pred[:, :, :, 1:] - x_pred[:, :, :, :-1])
            loss = loss + diff_w.mean()
        return loss

    def forward(self, v_pred, x0, x1, t, x_pred_clean=None, step=0):
        """Compute the combined loss.

        Args:
            v_pred: predicted velocity, same shape as ``x0`` / ``x1``.
            x0, x1: path endpoints; the flow target is ``x1 - x0``.
            t: interpolation times, one per batch element (reshaped to
                (-1, 1, 1, 1) to broadcast over the remaining dims).
            x_pred_clean: optional precomputed estimate to apply the physics
                terms to. If None, the one-step estimate of x1,
                ``x_t + (1 - t) * v_pred``, is used.
            step: training step, used for the physics-term warmup.

        Returns:
            ``(total, losses)``. ``losses`` holds ``'total'``, ``'flow'``,
            ``'smooth'``, ``'tv'`` and ``'physics_weight'`` (a 0-d tensor
            created with ``torch.tensor`` defaults).
        """
        loss_flow = self.flow_matching_loss(v_pred, x0, x1, t)

        if x_pred_clean is None:
            t_expand = t.reshape(-1, 1, 1, 1)
            x_t = t_expand * x1 + (1 - t_expand) * x0
            # Extrapolate along the predicted velocity to t = 1.
            x_pred_clean = x_t + v_pred * (1 - t_expand)

        physics_weight = self._physics_weight(step)

        loss_smooth = self.smoothness_loss(x_pred_clean)
        loss_tv = self.total_variation_loss(x_pred_clean)

        total = (loss_flow
                 + physics_weight * self.lambda_smooth * loss_smooth
                 + physics_weight * self.lambda_tv * loss_tv)

        losses = {
            'total': total,
            'flow': loss_flow,
            'smooth': loss_smooth,
            'tv': loss_tv,
            'physics_weight': torch.tensor(physics_weight),
        }
        return total, losses


class EMAModel:
    """Exponential Moving Average of model parameters.

    Keeps a shadow copy of every parameter that has ``requires_grad=True`` at
    construction time. Each ``update`` uses the effective decay
    ``min(decay, (1 + step) / (10 + step))``, so early updates track the live
    weights closely. ``warmup_steps`` is stored but not used by ``update``.
    """

    def __init__(self, model, decay=0.9999, warmup_steps=1000):
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.step = 0
        self.shadow = {
            name: param.data.clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        self.backup = {}

    def update(self, model):
        """Advance the step counter and blend current parameters into the shadow."""
        self.step += 1
        decay = min(self.decay, (1 + self.step) / (10 + self.step))
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = decay * self.shadow[name] + (1 - decay) * param.data

    def apply_shadow(self, model):
        """Back up the live parameters, then overwrite them with the shadow values.

        NOTE(review): calling this twice without ``restore`` replaces the
        backup with shadow values, so the original weights are lost.
        """
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self, model):
        """Copy backed-up parameters back into ``model`` and clear the backup."""
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

    def state_dict(self):
        """Return the shadow dict (by reference) and the step counter."""
        return {'shadow': self.shadow, 'step': self.step}

    def load_state_dict(self, state_dict):
        """Restore the shadow dict and step counter produced by ``state_dict``."""
        self.shadow = state_dict['shadow']
        self.step = state_dict['step']