"""
Physics-Informed Regularization for LiquidFlow.
CORRECTED VERSION: fixed intensity tracking, proper buffer handling.
Pattern from: Bastek & Sun (ICLR 2025)
- Physics losses computed on estimated x̂₀ during training
- Zero cost at inference
- Acts as implicit regularizer against artifacts
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class PhysicsRegularizer(nn.Module):
    """
    Physics-informed regularizer for diffusion training.

    Computed on estimated clean sample x̂₀ (DDIM one-step estimate).
    All losses are differentiable through the noise predictor.

    ``forward`` sums four weighted terms; a term whose weight is not
    positive is skipped entirely:
      - ``tv``:   L1 total variation of neighbouring pixels
      - ``cons``: squared deviation of the batch mean from its running mean
      - ``spec``: mean rFFT magnitude beyond half the Nyquist radius
      - ``grad``: mean squared finite difference of neighbouring pixels

    Args:
        tv_weight: Weight of the total-variation term.
        cons_weight: Weight of the intensity-conservation term.
        spec_weight: Weight of the spectral term.
        grad_weight: Weight of the gradient-penalty term.
        warmup_steps: Number of training updates during which the
            conservation term returns zero.
        ema_decay: Upper bound on the EMA decay of the running mean.
    """

    def __init__(self, tv_weight=0.01, cons_weight=0.001, spec_weight=0.01,
                 grad_weight=0.001, warmup_steps=100, ema_decay=0.99):
        super().__init__()
        self.tv_weight = tv_weight
        self.cons_weight = cons_weight
        self.spec_weight = spec_weight
        self.grad_weight = grad_weight
        self.warmup_steps = warmup_steps
        self.ema_decay = ema_decay
        # EMA intensity tracking (buffers: saved in state_dict, moved by .to()).
        self.register_buffer('intensity_ema', torch.tensor(0.0))
        self.register_buffer('step_count', torch.tensor(0, dtype=torch.long))

    @staticmethod
    def _zero_like_loss(ref):
        """Return a scalar zero on ``ref``'s device that participates in autograd."""
        return torch.zeros(1, device=ref.device, requires_grad=True).squeeze()

    @staticmethod
    def _mean_or_zero(t):
        """Return ``t.mean()``, or a zero (not NaN) when ``t`` has no elements."""
        # mean() of an empty tensor is NaN; sum() of an empty tensor is 0 and
        # keeps the result attached to the autograd graph.
        return t.mean() if t.numel() > 0 else t.sum()

    @staticmethod
    def _high_frequency_mask(height, width, device):
        """Return a [height, width // 2 + 1] float mask of high-frequency rFFT bins."""
        freq_h = torch.arange(height, device=device).float()
        freq_w = torch.arange(width // 2 + 1, device=device).float()
        # Normalize frequencies to [0, 1]; rows fold around height / 2 because
        # the full FFT axis stores negative frequencies in its upper half.
        freq_h = torch.min(freq_h, height - freq_h) / (height / 2)
        freq_w = freq_w / (width / 2)
        # Radial distance from the DC bin.
        dist = torch.sqrt(freq_h.unsqueeze(1) ** 2 + freq_w.unsqueeze(0) ** 2)
        # High frequency: distance > 0.5 (half Nyquist).
        return (dist > 0.5).float()

    def total_variation(self, x):
        """L1 total variation: encourages spatial smoothness.

        Returns the mean absolute difference along dim 2 plus the mean
        absolute difference along dim 3. An axis of length 1 contributes 0.
        """
        diff_h = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
        diff_w = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
        return self._mean_or_zero(diff_h) + self._mean_or_zero(diff_w)

    def conservation_intensity(self, x):
        """Penalize deviation from running mean intensity.

        In training mode each call increments ``step_count`` and folds the
        batch mean into ``intensity_ema``. Returns zero until ``step_count``
        exceeds ``warmup_steps``; afterwards returns the squared difference
        between the batch mean and the (detached) running mean.
        """
        batch_mean = x.mean()
        if self.training:
            with torch.no_grad():
                self.step_count += 1
                step = int(self.step_count)
                # 1 - 1/step makes the first update store the batch mean
                # itself, so the zero-initialised buffer adds no bias.
                alpha = min(self.ema_decay, 1.0 - 1.0 / step)
                observed = batch_mean.detach().to(
                    device=self.intensity_ema.device,
                    dtype=self.intensity_ema.dtype,
                )
                # Update in place so the registered buffer object, its dtype
                # and its device stay fixed.
                self.intensity_ema.mul_(alpha).add_(observed, alpha=1.0 - alpha)
        else:
            step = int(self.step_count)
        # Only activate after warmup.
        if step > self.warmup_steps:
            return (batch_mean - self.intensity_ema.detach()) ** 2
        return self._zero_like_loss(x)

    def spectral_regularizer(self, x):
        """Penalize high-frequency energy (anti-checkerboard).

        Takes the orthonormal 2D real FFT over the last two dims and returns
        the mean magnitude with every bin within half the Nyquist radius of
        DC zeroed out. The mean is taken over all bins, masked or not.
        """
        _, _, height, width = x.shape
        # For rfft2, output shape is [B, C, H, W//2+1].
        magnitude = torch.abs(torch.fft.rfft2(x, norm='ortho'))
        high_mask = self._high_frequency_mask(height, width, x.device)
        return (magnitude * high_mask).mean()

    def gradient_penalty(self, x):
        """Sobolev L2 gradient penalty.

        Returns the mean squared difference along dim 2 plus the mean squared
        difference along dim 3. An axis of length 1 contributes 0.
        """
        grad_h = x[:, :, 1:, :] - x[:, :, :-1, :]
        grad_w = x[:, :, :, 1:] - x[:, :, :, :-1]
        return self._mean_or_zero(grad_h ** 2) + self._mean_or_zero(grad_w ** 2)

    def forward(self, x0_hat, x_ref=None):
        """
        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Ground truth (unused, kept for API compat)
        Returns:
            total_loss, loss_dict — the weighted sum of the enabled terms and
            a dict of the unweighted terms keyed 'tv', 'cons', 'spec', 'grad'.
        """
        losses = {}
        # Start from a zero that participates in autograd so the total is
        # still a valid loss when every term is disabled.
        total = self._zero_like_loss(x0_hat)
        terms = (
            ('tv', self.tv_weight, self.total_variation),
            ('cons', self.cons_weight, self.conservation_intensity),
            ('spec', self.spec_weight, self.spectral_regularizer),
            ('grad', self.grad_weight, self.gradient_penalty),
        )
        for name, weight, term_fn in terms:
            if weight > 0:
                value = term_fn(x0_hat)
                losses[name] = value
                total = total + weight * value
        return total, losses
class DDIMEstimator:
    """DDIM one-step clean sample estimation."""

    @staticmethod
    def estimate_x0(x_t, eps_pred, alpha_bar_t, clamp_value=5.0):
        """
        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)

        Args:
            x_t: Noisy sample, batch-first (e.g. [B, C, H, W]).
            eps_pred: Predicted noise, same shape as ``x_t``.
            alpha_bar_t: [B] — cumulative alpha at timestep t. A Python
                number or 0-dim tensor is also accepted and applied to the
                whole batch.
            clamp_value: Result is clamped to [-clamp_value, clamp_value].
                Pass ``None`` to disable clamping.
        Returns:
            The estimated clean sample, same shape as ``x_t``.
        """
        if not torch.is_tensor(alpha_bar_t):
            # Plain numbers have no .reshape; lift them onto x_t's device
            # (and dtype, when x_t is floating point).
            alpha_bar_t = torch.as_tensor(
                alpha_bar_t,
                dtype=x_t.dtype if x_t.is_floating_point() else None,
                device=x_t.device,
            )
        # One trailing singleton axis per non-batch dim of x_t so the
        # per-sample coefficient broadcasts for any input rank.
        a = alpha_bar_t.reshape(-1, *([1] * (x_t.dim() - 1)))
        # The small epsilon keeps the division finite when alpha_bar_t is 0.
        x0_hat = (x_t - torch.sqrt(1 - a) * eps_pred) / (torch.sqrt(a) + 1e-8)
        if clamp_value is None:
            return x0_hat
        # Clamp to prevent extreme values early in training
        return x0_hat.clamp(-clamp_value, clamp_value)
|