"""
Physics-Informed Regularization for LiquidFlow.

CORRECTED VERSION: fixed intensity tracking, proper buffer handling.

Pattern from: Bastek & Sun (ICLR 2025)
- Physics losses computed on the estimated clean sample x̂₀ during training
- Zero cost at inference
- Acts as an implicit regularizer against artifacts
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


class PhysicsRegularizer(nn.Module):
    """
    Physics-informed regularizer for diffusion training.

    Computed on the estimated clean sample x̂₀ (DDIM one-step estimate).
    All losses are differentiable through the noise predictor.

    Args:
        tv_weight: Weight of the L1 total-variation term (0 disables it).
        cons_weight: Weight of the intensity-conservation term (0 disables it).
        spec_weight: Weight of the high-frequency spectral term (0 disables it).
        grad_weight: Weight of the L2 (Sobolev) gradient term (0 disables it).
        ema_decay: Upper bound on the decay factor of the running intensity
            average (0.99 reproduces the previous hard-coded value).
        cons_warmup: The conservation term is only returned once more than
            this many training updates of the running average have happened.
    """

    def __init__(self, tv_weight=0.01, cons_weight=0.001, spec_weight=0.01,
                 grad_weight=0.001, ema_decay=0.99, cons_warmup=100):
        super().__init__()
        self.tv_weight = tv_weight
        self.cons_weight = cons_weight
        self.spec_weight = spec_weight
        self.grad_weight = grad_weight
        self.ema_decay = ema_decay
        self.cons_warmup = cons_warmup

        # Running intensity statistics. Registered as buffers so they follow
        # .to(device) and are stored in the state_dict (names must not change).
        self.register_buffer('intensity_ema', torch.tensor(0.0))
        self.register_buffer('step_count', torch.tensor(0, dtype=torch.long))

    @staticmethod
    def _finite_differences(x):
        """Return forward differences of x along the last two (spatial) dims."""
        return torch.diff(x, dim=-2), torch.diff(x, dim=-1)

    @staticmethod
    def _mean_or_zero(t):
        """Mean of t, or a zero scalar when t is empty (mean() would give NaN)."""
        if t.numel() == 0:
            return t.new_zeros(())
        return t.mean()

    def total_variation(self, x):
        """L1 total variation: encourages spatial smoothness.

        Returns mean|Δ_h x| + mean|Δ_w x|; a spatial dim of size 1 adds 0.
        """
        diff_h, diff_w = self._finite_differences(x)
        return self._mean_or_zero(diff_h.abs()) + self._mean_or_zero(diff_w.abs())

    def conservation_intensity(self, x):
        """Penalize deviation of the batch mean from the running mean intensity.

        In training mode every call first folds the batch mean into
        `intensity_ema` and increments `step_count`. The squared deviation is
        returned only once `step_count` exceeds `cons_warmup`; before that a
        zero scalar is returned.
        """
        batch_mean = x.mean()

        if self.training:
            with torch.no_grad():
                self.step_count += 1
                n = int(self.step_count.item())
                # alpha = 1 - 1/n makes the early updates an exact running mean
                # (step 1 copies batch_mean), so the arbitrary 0.0 initial value
                # never biases the estimate; once alpha reaches ema_decay it is
                # a standard EMA.
                alpha = min(self.ema_decay, 1.0 - 1.0 / n)
                # In-place update keeps the buffer's identity, dtype and device.
                self.intensity_ema.mul_(alpha).add_(
                    batch_mean.detach().to(self.intensity_ema), alpha=1.0 - alpha
                )

        if int(self.step_count.item()) > self.cons_warmup:
            return (batch_mean - self.intensity_ema.detach()) ** 2
        return torch.zeros((), device=x.device, requires_grad=True)

    def spectral_regularizer(self, x):
        """Penalize high-frequency energy (anti-checkerboard).

        Takes the orthonormal real 2-D FFT over the last two dims and returns
        the mean spectral magnitude over all bins, with bins whose normalized
        radial frequency is <= 0.5 (half Nyquist) zeroed out.
        """
        H, W = x.shape[-2:]
        # rfft2 output shape is [..., H, W // 2 + 1]; DC sits at bin [0, 0]
        # (the spectrum is not shifted).
        mag = torch.abs(torch.fft.rfft2(x, norm='ortho'))

        # Frequencies normalized so Nyquist == 1 along each axis. The H axis is
        # a full FFT (indices past H/2 are negative frequencies, hence the
        # min with H - k); the W axis of rfft2 holds non-negative ones only.
        freq_h = torch.arange(H, device=x.device).float()
        freq_h = torch.min(freq_h, H - freq_h) / (H / 2)
        freq_w = torch.arange(W // 2 + 1, device=x.device).float() / (W / 2)

        # Radial distance from the DC bin.
        dist = torch.sqrt(freq_h.unsqueeze(1) ** 2 + freq_w.unsqueeze(0) ** 2)

        # High frequency: beyond half Nyquist. The [H, W//2+1] mask broadcasts
        # over all leading (batch/channel) dims.
        high_mask = (dist > 0.5).float()
        return (mag * high_mask).mean()

    def gradient_penalty(self, x):
        """Sobolev L2 gradient penalty: mean squared forward differences.

        A spatial dim of size 1 contributes 0.
        """
        diff_h, diff_w = self._finite_differences(x)
        return self._mean_or_zero(diff_h ** 2) + self._mean_or_zero(diff_w ** 2)

    def forward(self, x0_hat, x_ref=None):
        """
        Args:
            x0_hat: Estimated clean image [B, C, H, W]
            x_ref: Ground truth (unused, kept for API compat)

        Returns:
            total_loss, loss_dict. loss_dict holds the unweighted value of every
            term whose weight is > 0, under keys 'tv', 'cons', 'spec', 'grad'.
        """
        terms = (
            ('tv', self.tv_weight, self.total_variation),
            ('cons', self.cons_weight, self.conservation_intensity),
            ('spec', self.spec_weight, self.spectral_regularizer),
            ('grad', self.grad_weight, self.gradient_penalty),
        )

        losses = {}
        # requires_grad so total.backward() works even if no term is active.
        total = torch.zeros((), device=x0_hat.device, requires_grad=True)
        for name, weight, loss_fn in terms:
            if weight > 0:
                value = loss_fn(x0_hat)
                losses[name] = value
                total = total + weight * value
        return total, losses


class DDIMEstimator:
    """DDIM one-step clean sample estimation."""

    @staticmethod
    def estimate_x0(x_t, eps_pred, alpha_bar_t, clip=5.0, eps=1e-8):
        """
        x̂₀ = (x_t - √(1-ᾱ_t) · ε_pred) / √(ᾱ_t)

        Args:
            x_t: Noisy sample [B, ...] (typically [B, C, H, W]).
            eps_pred: Predicted noise, same shape as x_t.
            alpha_bar_t: Cumulative alpha at timestep t. Either a [B] tensor
                (one value per sample), a 0-d tensor, or a Python float that is
                shared by the whole batch.
            clip: Symmetric bound used to clamp x̂₀ (None disables clamping).
            eps: Added to √ᾱ_t to avoid division by zero.

        Returns:
            x̂₀ with the same shape as x_t.
        """
        a = torch.as_tensor(alpha_bar_t, device=x_t.device)
        # One value per sample, broadcast over every non-batch dim.
        a = a.reshape(-1, *([1] * (x_t.dim() - 1)))
        # Clamp 1 - ᾱ at 0 so rounding (ᾱ marginally above 1) cannot give NaN.
        noise_scale = torch.sqrt((1 - a).clamp(min=0))
        x0_hat = (x_t - noise_scale * eps_pred) / (torch.sqrt(a) + eps)
        # Clamp to prevent extreme values early in training
        if clip is not None:
            x0_hat = x0_hat.clamp(-clip, clip)
        return x0_hat