# Source: LiquidFlow / liquidflow/losses.py
# Uploaded by krystv in commit 2b4ad8c (verified):
#   "Add losses.py, sampling.py, train.py, smoke_test.py"
"""
Physics-Informed Loss for LiquidFlow.
Combines:
1. Rectified Flow Matching loss: ||v_θ(x_t, t) - (x_1 - x_0)||²
2. Physics residual: smoothness + continuity constraints on generated images
3. ConFIG gradient conflict-free update (from TUM-PBS, 2024)
4. Adaptive loss weighting (from PINN gradient pathology research, Wang et al. 2020)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PhysicsInformedFlowLoss(nn.Module):
    """
    Multi-objective loss for physics-informed flow matching.

    L = L_flow + w(step) * (λ_smooth * L_smooth + λ_tv * L_tv)

    Where:
    - L_flow: Rectified flow matching MSE, ||v_pred - (x1 - x0)||²
    - L_smooth: mean squared 5-point discrete Laplacian of the predicted
      clean sample (residual of the steady-state heat equation)
    - L_tv: anisotropic total variation for edge preservation
    - w(step): linear warm-up from 0 to 1 over ``warmup_steps`` steps

    Interpolation convention used by ``forward``:
    x_t = t * x1 + (1 - t) * x0, so t = 0 is x0 and t = 1 is x1.
    Image tensors are indexed as 4-D; the last two dims are treated as the
    pixel grid.

    Args:
        lambda_smooth: weight of the Laplacian smoothness term.
        lambda_tv: weight of the total-variation term.
        lambda_continuity: stored on the module but not used by this class.
        use_adaptive_weights: stored on the module but not used by this class.
        warmup_steps: number of steps over which the physics terms are ramped
            in linearly; a value ``<= 0`` disables the warm-up.
    """

    def __init__(self, lambda_smooth=0.01, lambda_tv=0.001,
                 lambda_continuity=0.005, use_adaptive_weights=True,
                 *, warmup_steps=500):
        super().__init__()
        self.lambda_smooth = lambda_smooth
        self.lambda_tv = lambda_tv
        self.lambda_continuity = lambda_continuity  # NOTE: not read in this class
        self.use_adaptive_weights = use_adaptive_weights  # NOTE: not read in this class
        self.warmup_steps = warmup_steps
        # Gradient-norm statistics registered as buffers so they travel with
        # state_dict(); they are not read by forward() yet, but are kept so
        # existing checkpoints containing these keys still load.
        self.register_buffer('flow_grad_norm', torch.tensor(1.0))
        self.register_buffer('physics_grad_norm', torch.tensor(1.0))
        self.ema_decay = 0.99  # NOTE: not read in this class

    def flow_matching_loss(self, v_pred, x0, x1, t):
        """Return the MSE between the predicted velocity and the straight-line
        velocity ``x1 - x0`` (``t`` is accepted for interface symmetry and is
        not used)."""
        target = x1 - x0
        return F.mse_loss(v_pred, target)

    def smoothness_loss(self, x_pred):
        """Return the mean squared 5-point discrete Laplacian of ``x_pred``.

        The Laplacian is evaluated on interior pixels only, giving an
        (H-2) x (W-2) grid. Inputs whose last two dims are smaller than 3 have
        no interior pixel; for them a zero scalar is returned instead of the
        NaN produced by averaging an empty tensor.
        """
        if x_pred.shape[-2] < 3 or x_pred.shape[-1] < 3:
            return x_pred.new_zeros(())
        center = x_pred[:, :, 1:-1, 1:-1]
        # Both second differences are centred on the same interior pixel
        # (i+1, j+1); cropping them independently would misalign the terms.
        d2_h = x_pred[:, :, 2:, 1:-1] - 2 * center + x_pred[:, :, :-2, 1:-1]
        d2_w = x_pred[:, :, 1:-1, 2:] - 2 * center + x_pred[:, :, 1:-1, :-2]
        laplacian = d2_h + d2_w
        return (laplacian ** 2).mean()

    def total_variation_loss(self, x_pred):
        """Return the anisotropic total variation of ``x_pred``.

        This is the mean absolute vertical difference plus the mean absolute
        horizontal difference. A direction with no neighbouring pixel pairs
        (size-1 dim) contributes 0 rather than NaN.
        """
        zero = x_pred.new_zeros(())
        diff_h = torch.abs(x_pred[:, :, 1:, :] - x_pred[:, :, :-1, :])
        diff_w = torch.abs(x_pred[:, :, :, 1:] - x_pred[:, :, :, :-1])
        tv_h = diff_h.mean() if diff_h.numel() else zero
        tv_w = diff_w.mean() if diff_w.numel() else zero
        return tv_h + tv_w

    def _physics_weight(self, step):
        """Linear warm-up factor in [0, 1] for the physics terms at ``step``."""
        if self.warmup_steps <= 0:
            return 1.0
        return min(1.0, max(0.0, step / self.warmup_steps))

    def forward(self, v_pred, x0, x1, t, x_pred_clean=None, step=0):
        """Compute the total loss and its components.

        Args:
            v_pred: predicted velocity, same shape as ``x0`` / ``x1``.
            x0, x1: endpoints of the straight interpolation path.
            t: interpolation times; reshaped to (-1, 1, 1, 1), so it must hold
                one value per batch sample (or a single value).
            x_pred_clean: optional clean sample for the physics terms. If
                None, it is reconstructed as x_t + (1 - t) * v_pred, which
                equals x1 when v_pred is exact.
            step: current training step, used for the physics warm-up.

        Returns:
            ``(total, losses)``, where ``losses`` maps 'total', 'flow',
            'smooth', 'tv' and 'physics_weight' to tensors.
        """
        loss_flow = self.flow_matching_loss(v_pred, x0, x1, t)
        if x_pred_clean is None:
            t_expand = t.reshape(-1, 1, 1, 1)
            x_t = t_expand * x1 + (1 - t_expand) * x0
            # Integrate the (constant) velocity from t to 1 to land on x1.
            x_pred_clean = x_t + v_pred * (1 - t_expand)
        physics_weight = self._physics_weight(step)
        loss_smooth = self.smoothness_loss(x_pred_clean)
        loss_tv = self.total_variation_loss(x_pred_clean)
        total = (loss_flow
                 + physics_weight * self.lambda_smooth * loss_smooth
                 + physics_weight * self.lambda_tv * loss_tv)
        losses = {
            'total': total, 'flow': loss_flow,
            'smooth': loss_smooth, 'tv': loss_tv,
            'physics_weight': torch.tensor(physics_weight),
        }
        return total, losses
class EMAModel:
    """Exponential Moving Average of model parameters.

    Tracks a "shadow" copy of every parameter that has ``requires_grad=True``
    when the EMA is created. Each ``update`` blends the live weights into the
    shadow with an effective decay of ``min(decay, (1 + n) / (10 + n))``,
    where ``n`` is the number of updates performed so far. Early averages
    therefore follow the live weights closely.

    ``warmup_steps`` is stored on the instance but is not read by the decay
    schedule.
    """

    def __init__(self, model, decay=0.9999, warmup_steps=1000):
        self.decay = decay
        self.warmup_steps = warmup_steps  # NOTE: not read by update()
        self.step = 0
        self.backup = {}
        self.shadow = {
            name: tensor.data.clone()
            for name, tensor in self._trainable(model)
        }

    @staticmethod
    def _trainable(model):
        """Yield ``(name, parameter)`` pairs that currently require grad."""
        for name, tensor in model.named_parameters():
            if tensor.requires_grad:
                yield name, tensor

    def update(self, model):
        """Blend the model's current trainable weights into the shadow copy."""
        self.step += 1
        keep = min(self.decay, (1 + self.step) / (10 + self.step))
        blend = 1 - keep
        for name, tensor in self._trainable(model):
            # A fresh tensor is stored each time (not an in-place update).
            self.shadow[name] = keep * self.shadow[name] + blend * tensor.data

    def apply_shadow(self, model):
        """Load the EMA weights into ``model``, saving the live weights first."""
        self.backup = {}
        for name, tensor in self._trainable(model):
            self.backup[name] = tensor.data.clone()
            tensor.data.copy_(self.shadow[name])

    def restore(self, model):
        """Copy the weights saved by ``apply_shadow`` back and clear them."""
        saved = self.backup
        for name, tensor in self._trainable(model):
            if name not in saved:
                continue
            tensor.data.copy_(saved[name])
        self.backup = {}

    def state_dict(self):
        """Return the shadow weights (by reference) and the update counter."""
        return {'shadow': self.shadow, 'step': self.step}

    def load_state_dict(self, state_dict):
        """Adopt the shadow dict and update counter from ``state_dict``."""
        self.shadow, self.step = state_dict['shadow'], state_dict['step']