File size: 4,489 Bytes
2b4ad8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""
Physics-Informed Loss for LiquidFlow.

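Combines:
1. Rectified Flow Matching loss: ||v_θ(x_t, t) - (x_1 - x_0)||²
2. Physics residual: Laplacian smoothness + total-variation penalties on the predicted clean image
3. ConFIG gradient conflict-free update (from TUM-PBS, 2024) — referenced design; not implemented in this module
4. Adaptive loss weighting (from PINN gradient pathology research, Wang et al. 2020) — hyper-parameters and buffers are declared, but the weighting is not applied yet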
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class PhysicsInformedFlowLoss(nn.Module):
    """
    Multi-objective loss for physics-informed flow matching.

    L = L_flow + w * (λ_smooth * L_smooth + λ_tv * L_tv)

    Where:
    - L_flow: Rectified flow matching MSE between the predicted velocity
      and ``x1 - x0``
    - L_smooth: mean squared 5-point discrete Laplacian of the predicted
      clean image (heat equation steady state)
    - L_tv: anisotropic total variation (mean absolute neighbour difference
      along height plus along width)
    - w: linear warmup factor that ramps from 0 to 1 over ``warmup_steps``

    ``lambda_continuity``, ``use_adaptive_weights``, ``ema_decay`` and the
    ``flow_grad_norm`` / ``physics_grad_norm`` buffers are stored on the
    module but are not read by ``forward``.
    """

    def __init__(self, lambda_smooth=0.01, lambda_tv=0.001,
                 lambda_continuity=0.005, use_adaptive_weights=True,
                 warmup_steps=500):
        """
        Args:
            lambda_smooth: weight of the Laplacian smoothness term.
            lambda_tv: weight of the total variation term.
            lambda_continuity: stored only; not used by ``forward``.
            use_adaptive_weights: stored only; not used by ``forward``.
            warmup_steps: number of steps over which the physics terms are
                linearly ramped in. A value <= 0 disables the warmup.
        """
        super().__init__()
        self.lambda_smooth = lambda_smooth
        self.lambda_tv = lambda_tv
        self.lambda_continuity = lambda_continuity
        self.use_adaptive_weights = use_adaptive_weights
        self.warmup_steps = warmup_steps
        # Kept as buffers so existing checkpoints keep loading unchanged.
        self.register_buffer('flow_grad_norm', torch.tensor(1.0))
        self.register_buffer('physics_grad_norm', torch.tensor(1.0))
        self.ema_decay = 0.99

    @staticmethod
    def _mean_or_zero(values, like):
        """Return ``values.mean()``, or a scalar zero when ``values`` is empty.

        The mean of an empty tensor is NaN, which would poison the total
        loss for inputs too small to have any neighbour differences.
        """
        if values.numel() == 0:
            return like.new_zeros(())
        return values.mean()

    def flow_matching_loss(self, v_pred, x0, x1, t):
        """Return the MSE between ``v_pred`` and the target ``x1 - x0``.

        ``t`` is accepted for interface symmetry but is not used here.
        """
        target = x1 - x0
        return F.mse_loss(v_pred, target)

    def smoothness_loss(self, x_pred):
        """Return the mean squared discrete Laplacian of ``x_pred``.

        ``x_pred`` is indexed as a 4-D tensor whose last two axes are the
        spatial ones. Only interior pixels (those with all four neighbours)
        contribute; if there are none the result is a scalar zero.
        """
        # Second differences along the height and width axes.
        lap_h = x_pred[:, :, 2:, :] - 2 * x_pred[:, :, 1:-1, :] + x_pred[:, :, :-2, :]
        lap_w = x_pred[:, :, :, 2:] - 2 * x_pred[:, :, :, 1:-1] + x_pred[:, :, :, :-2]
        # lap_h[i, j] is centred on pixel (i + 1, j) and lap_w[i, j] on
        # pixel (i, j + 1). Drop the border columns of lap_h and the border
        # rows of lap_w so both terms refer to the same interior pixel
        # (i + 1, j + 1) before they are summed.
        laplacian = lap_h[:, :, :, 1:-1] + lap_w[:, :, 1:-1, :]
        return self._mean_or_zero(laplacian ** 2, x_pred)

    def total_variation_loss(self, x_pred):
        """Return mean |vertical difference| + mean |horizontal difference|.

        An axis of length 1 has no neighbour pairs and contributes zero.
        """
        diff_h = torch.abs(x_pred[:, :, 1:, :] - x_pred[:, :, :-1, :])
        diff_w = torch.abs(x_pred[:, :, :, 1:] - x_pred[:, :, :, :-1])
        return (self._mean_or_zero(diff_h, x_pred)
                + self._mean_or_zero(diff_w, x_pred))

    def forward(self, v_pred, x0, x1, t, x_pred_clean=None, step=0):
        """Compute the combined loss.

        Args:
            v_pred: predicted velocity, compared against ``x1 - x0``.
            x0: source sample.
            x1: target sample.
            t: per-sample time; reshaped to ``(-1, 1, 1, 1)`` so it
                broadcasts over 4-D inputs.
            x_pred_clean: optional predicted clean image for the physics
                terms. When omitted it is reconstructed from the linear
                interpolant as ``x_t + (1 - t) * v_pred``.
            step: current training step, used for the physics warmup.

        Returns:
            ``(total, losses)`` where ``losses`` maps ``'total'``,
            ``'flow'``, ``'smooth'``, ``'tv'`` and ``'physics_weight'`` to
            scalar tensors.
        """
        loss_flow = self.flow_matching_loss(v_pred, x0, x1, t)

        if x_pred_clean is None:
            t_expand = t.view(-1, 1, 1, 1)
            x_t = t_expand * x1 + (1 - t_expand) * x0
            # One Euler step along the predicted velocity from x_t to t = 1.
            x_pred_clean = x_t + v_pred * (1 - t_expand)

        # Linear ramp clamped to [0, 1]; the lower clamp keeps a negative
        # step from flipping the sign of the physics terms.
        if self.warmup_steps > 0:
            physics_weight = min(1.0, max(0.0, step / self.warmup_steps))
        else:
            physics_weight = 1.0

        loss_smooth = self.smoothness_loss(x_pred_clean)
        loss_tv = self.total_variation_loss(x_pred_clean)

        total = (loss_flow
                 + physics_weight * self.lambda_smooth * loss_smooth
                 + physics_weight * self.lambda_tv * loss_tv)

        losses = {
            'total': total, 'flow': loss_flow,
            'smooth': loss_smooth, 'tv': loss_tv,
            'physics_weight': torch.tensor(physics_weight),
        }
        return total, losses


class EMAModel:
    """Exponential Moving Average of model parameters.

    A "shadow" copy is kept for every parameter with ``requires_grad=True``.
    ``update`` blends the live weights into the shadow, ``apply_shadow``
    swaps the shadow into the model (saving the live weights), and
    ``restore`` swaps the live weights back.
    """

    def __init__(self, model, decay=0.9999, warmup_steps=1000):
        """
        Args:
            model: module whose trainable parameters are tracked.
            decay: upper bound on the EMA decay factor.
            warmup_steps: stored on the instance; the decay ramp in
                ``update`` does not read it.
        """
        self.decay = decay
        self.warmup_steps = warmup_steps
        self.step = 0
        self.shadow = {}
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, model):
        """Blend the model's current trainable parameters into the shadow.

        The effective decay is ``min(decay, (1 + step) / (10 + step))``, so
        early updates follow the live weights more closely.
        """
        self.step += 1
        decay = min(self.decay, (1 + self.step) / (10 + self.step))
        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue
            current = param.data
            previous = self.shadow.get(name)
            if previous is None:
                # Parameter became trainable after construction (or is
                # absent from a loaded state): start tracking it now.
                self.shadow[name] = current.clone()
                continue
            # A shadow restored from a checkpoint may live on another
            # device than the model; ``to`` is a no-op when they match.
            previous = previous.to(current.device)
            # Out-of-place on purpose: tensors already handed out by
            # state_dict() are never mutated.
            self.shadow[name] = decay * previous + (1 - decay) * current

    def apply_shadow(self, model):
        """Copy the EMA weights into ``model``, backing up the live weights.

        Parameters without a shadow entry are left untouched.
        """
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.shadow:
                self.backup[name] = param.data.clone()
                param.data.copy_(self.shadow[name])

    def restore(self, model):
        """Copy the weights saved by ``apply_shadow`` back into ``model``."""
        for name, param in model.named_parameters():
            if param.requires_grad and name in self.backup:
                param.data.copy_(self.backup[name])
        self.backup = {}

    def state_dict(self):
        """Return a snapshot ``{'shadow': {name: tensor}, 'step': int}``.

        The shadow mapping is a copy, so later ``update`` calls do not
        change a snapshot that was already returned.
        """
        return {'shadow': dict(self.shadow), 'step': self.step}

    def load_state_dict(self, state_dict):
        """Restore shadow weights and step from a ``state_dict()`` snapshot.

        Tensors are cloned so this instance never shares storage or the
        mapping itself with the caller's object.
        """
        self.shadow = {
            name: tensor.detach().clone()
            for name, tensor in state_dict['shadow'].items()
        }
        self.step = state_dict['step']