File size: 4,934 Bytes
1a3345c
 
0cf988f
1a3345c
0cf988f
 
 
 
1a3345c
 
 
 
 
 
 
 
 
0cf988f
1a3345c
0cf988f
 
1a3345c
 
0cf988f
1a3345c
 
 
 
 
 
0cf988f
 
 
1a3345c
 
0cf988f
1a3345c
 
 
 
 
0cf988f
1a3345c
 
 
 
0cf988f
 
 
 
 
 
 
 
1a3345c
 
0cf988f
 
1a3345c
 
0cf988f
 
1a3345c
0cf988f
 
 
 
1a3345c
0cf988f
 
 
1a3345c
0cf988f
 
1a3345c
0cf988f
 
1a3345c
0cf988f
 
1a3345c
 
0cf988f
1a3345c
 
0cf988f
1a3345c
 
 
 
 
0cf988f
1a3345c
0cf988f
1a3345c
 
0cf988f
1a3345c
 
0cf988f
 
 
1a3345c
 
0cf988f
 
 
1a3345c
 
0cf988f
 
 
1a3345c
 
0cf988f
 
 
1a3345c
 
 
 
 
0cf988f
1a3345c
 
 
 
0cf988f
1a3345c
 
0cf988f
 
 
1a3345c
0cf988f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
"""
Physics-Informed Regularization for LiquidFlow.
CORRECTED VERSION: fixed intensity tracking, proper buffer handling.

Pattern from: Bastek & Sun (ICLR 2025)
- Physics losses computed on estimated x̂₀ during training
- Zero cost at inference
- Acts as implicit regularizer against artifacts
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class PhysicsRegularizer(nn.Module):
    """
    Physics-informed regularizer for diffusion training.

    Computed on estimated clean sample x̂₀ (DDIM one-step estimate).
    All losses are differentiable through the noise predictor.

    Args:
        tv_weight: Weight of the L1 total-variation term (0 disables it).
        cons_weight: Weight of the intensity-conservation term (0 disables it).
        spec_weight: Weight of the high-frequency spectral term (0 disables it).
        grad_weight: Weight of the L2 (Sobolev) gradient term (0 disables it).
        warmup_steps: Number of training-mode EMA updates that must happen
            before the conservation term becomes active.
        ema_decay: Upper bound on the decay factor of the running-intensity EMA.
    """

    def __init__(self, tv_weight=0.01, cons_weight=0.001, spec_weight=0.01,
                 grad_weight=0.001, warmup_steps=100, ema_decay=0.99):
        super().__init__()
        self.tv_weight = tv_weight
        self.cons_weight = cons_weight
        self.spec_weight = spec_weight
        self.grad_weight = grad_weight
        self.warmup_steps = warmup_steps
        self.ema_decay = ema_decay

        # EMA intensity tracking. These are buffers so they are saved in the
        # state_dict and follow .to(device).
        self.register_buffer('intensity_ema', torch.tensor(0.0))
        self.register_buffer('step_count', torch.tensor(0, dtype=torch.long))

    @staticmethod
    def _spatial_diffs(x):
        """Return the forward differences of x along dim 2 (H) and dim 3 (W)."""
        diff_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        diff_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        return diff_h, diff_w

    @staticmethod
    def _safe_mean(t):
        """Return the mean of t, or zero (same dtype/device, still in the graph) if t is empty."""
        # An empty tensor's .mean() is NaN (e.g. H == 1), which would poison the loss.
        return t.mean() if t.numel() > 0 else t.sum()

    def total_variation(self, x):
        """L1 total variation: encourages spatial smoothness."""
        diff_h, diff_w = self._spatial_diffs(x)
        return self._safe_mean(diff_h.abs()) + self._safe_mean(diff_w.abs())

    def conservation_intensity(self, x):
        """Penalize the squared deviation of the batch mean from the running mean intensity.

        In training mode each call first updates the EMA buffers. The penalty is
        only returned once more than ``warmup_steps`` updates have happened.
        Before that, a zero tensor is returned.
        """
        batch_mean = x.mean()

        if self.training:
            with torch.no_grad():
                self.step_count += 1
                n = int(self.step_count)
                # A cumulative average (alpha = 1 - 1/n) until it reaches the
                # decay cap. At n == 1 alpha is 0, so the EMA starts exactly at
                # the first batch mean instead of being biased toward the zero
                # initialisation.
                alpha = min(self.ema_decay, 1.0 - 1.0 / n)
                # In-place update keeps the registered buffer object stable.
                self.intensity_ema.lerp_(batch_mean.detach().to(self.intensity_ema), 1.0 - alpha)

        # Only activate after warmup
        if self.step_count > self.warmup_steps:
            return (batch_mean - self.intensity_ema.detach()) ** 2
        return torch.zeros(1, device=x.device, requires_grad=True).squeeze()

    def spectral_regularizer(self, x):
        """Penalize high-frequency energy (anti-checkerboard)."""
        H, W = x.shape[-2], x.shape[-1]

        # 2D FFT. torch.fft does not support every reduced-precision dtype
        # (e.g. bfloat16), so transform in at least float32.
        x_work = x.to(torch.promote_types(x.dtype, torch.float32))
        x_fft = torch.fft.rfft2(x_work, norm='ortho')
        mag = torch.abs(x_fft)

        # Radial high-frequency mask. For rfft2 the output shape is
        # [B, C, H, W//2+1]: the H axis is two-sided (index k and H-k are the
        # same frequency) and the W axis only holds non-negative frequencies.
        freq_h = torch.arange(H, device=x.device).float()
        freq_w = torch.arange(W // 2 + 1, device=x.device).float()

        # Normalize frequencies to [0, 1] (1 == Nyquist)
        freq_h = torch.minimum(freq_h, H - freq_h) / (H / 2)
        freq_w = freq_w / (W / 2)

        # Distance from DC (index 0 on both axes)
        dist = torch.sqrt(freq_h.unsqueeze(1) ** 2 + freq_w.unsqueeze(0) ** 2)

        # High frequency: distance > 0.5 (half Nyquist)
        high_mask = (dist > 0.5).float()

        # Mean over ALL bins (masked-out bins contribute zero).
        return (mag * high_mask.unsqueeze(0).unsqueeze(0)).mean()

    def gradient_penalty(self, x):
        """Sobolev L2 gradient penalty."""
        grad_h, grad_w = self._spatial_diffs(x)
        return self._safe_mean(grad_h ** 2) + self._safe_mean(grad_w ** 2)

    def forward(self, x0_hat, x_ref=None):
        """
        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Ground truth (unused, kept for API compat)
        Returns:
            total_loss, loss_dict — the weighted sum of the enabled terms, and
            the unweighted value of each enabled term keyed by
            'tv' / 'cons' / 'spec' / 'grad'.
        """
        losses = {}
        total = torch.zeros(1, device=x0_hat.device, requires_grad=True).squeeze()

        terms = (
            ('tv', self.tv_weight, self.total_variation),
            ('cons', self.cons_weight, self.conservation_intensity),
            ('spec', self.spec_weight, self.spectral_regularizer),
            ('grad', self.grad_weight, self.gradient_penalty),
        )
        for key, weight, fn in terms:
            # Disabled terms are skipped entirely. This matters for 'cons',
            # whose call also updates the EMA buffers.
            if weight > 0:
                value = fn(x0_hat)
                losses[key] = value
                total = total + weight * value

        return total, losses


class DDIMEstimator:
    """DDIM one-step clean sample estimation."""

    @staticmethod
    def estimate_x0(x_t, eps_pred, alpha_bar_t, clip=5.0, eps=1e-8):
        """
        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)

        Args:
            x_t: Noisy sample, [B, ...] (e.g. [B, C, H, W] or [B, C, T, H, W]).
            eps_pred: Predicted noise, same shape as x_t.
            alpha_bar_t: Cumulative alpha at timestep t. Either a [B] tensor
                (one value per sample), a 0-dim tensor, or a Python float
                shared by the whole batch.
            clip: Symmetric clamp bound applied to the estimate; None disables
                clamping.
            eps: Added to √ᾱ_t to avoid division by zero.
        Returns:
            x̂₀ with the same shape as x_t.
        """
        a = torch.as_tensor(alpha_bar_t, device=x_t.device)
        # Reshape to [B, 1, ..., 1] so it broadcasts over every non-batch dim,
        # whatever the rank of the sample.
        a = a.reshape(-1, *([1] * (x_t.dim() - 1)))
        # clamp guards sqrt against tiny negative values from rounding.
        x0_hat = (x_t - torch.sqrt((1 - a).clamp(min=0)) * eps_pred) / (torch.sqrt(a) + eps)
        # Clamp to prevent extreme values early in training
        if clip is None:
            return x0_hat
        return x0_hat.clamp(-clip, clip)