| """ |
| ================================================================================ |
| SENTINEL DIFFUSION MODEL |
| ================================================================================ |
| |
| Theory: Standard diffusion models use Gaussian noise schedules. |
| The Sentinel prior P(n) ∝ zⁿ/nⁿ has super-exponential decay, creating |
| sharper transitions between noise levels. |
| |
| Key Innovation: Sentinel noise schedule for faster convergence and |
| sharper transitions in diffusion-based generative models. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| from typing import Tuple |
|
|
| class SentinelNoiseSchedule: |
| """ |
| Sentinel noise schedule based on the partition function F(z) = Σ zⁿ/nⁿ. |
| |
| The noise levels are distributed according to the Sentinel PMF: |
| β_t ∝ t^t / T^T (super-exponentially decaying) |
| |
| This creates a schedule where: |
| - Early steps: small noise (high precision in structure) |
| - Late steps: large noise (coarse structure) |
| - Transition is SHARPER than Gaussian schedules |
| """ |
| |
| def __init__(self, timesteps: int = 1000, z: float = 2.0): |
| self.timesteps = timesteps |
| self.z = z |
| |
| |
| self.betas = self._sentinel_schedule() |
| self.alphas = 1.0 - self.betas |
| self.alpha_bars = torch.cumprod(self.alphas, dim=0) |
| |
| def _sentinel_schedule(self) -> torch.Tensor: |
| """Generate Sentinel noise schedule.""" |
| n = torch.arange(1, self.timesteps + 1, dtype=torch.float64) |
| |
| |
| |
| |
| |
| t_norm = n / self.timesteps |
| beta = torch.zeros_like(n) |
| |
| |
| |
| for i in range(self.timesteps): |
| t = t_norm[i].item() |
| |
| if t < 0.5: |
| beta[i] = 0.0001 + 0.01 * (2 * t) ** (1 / (2 * t + 0.01)) |
| else: |
| beta[i] = 0.01 + 0.02 * ((2 * t - 1) ** (2 * t - 1)) |
| |
| beta = torch.clamp(beta, 0.0001, 0.999) |
| return beta.float() |
| |
| def add_noise(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: |
| """Add noise at timestep t.""" |
| sqrt_alpha_bar = torch.sqrt(self.alpha_bars[t]) |
| sqrt_one_minus_alpha_bar = torch.sqrt(1.0 - self.alpha_bars[t]) |
| |
| noise = torch.randn_like(x) |
| noisy_x = sqrt_alpha_bar.view(-1, 1, 1, 1) * x + \ |
| sqrt_one_minus_alpha_bar.view(-1, 1, 1, 1) * noise |
| |
| return noisy_x, noise |
| |
| def sample_timesteps(self, batch_size: int) -> torch.Tensor: |
| """Sample timesteps according to Sentinel distribution.""" |
| |
| weights = 1.0 / (self.betas + 1e-8) |
| weights = weights / weights.sum() |
| return torch.multinomial(weights, batch_size, replacement=True) |
|
|
|
|
| class SentinelUNet(nn.Module): |
| """Simple UNet for diffusion with Sentinel activations.""" |
| |
| def __init__(self, in_channels: int = 3, time_emb_dim: int = 256): |
| super().__init__() |
| self.time_mlp = nn.Sequential( |
| nn.Linear(1, time_emb_dim), |
| nn.SiLU(), |
| nn.Linear(time_emb_dim, time_emb_dim) |
| ) |
| |
| |
| self.enc1 = self._conv_block(in_channels, 64) |
| self.enc2 = self._conv_block(64, 128) |
| self.dec2 = self._conv_block(128 + time_emb_dim, 64) |
| self.dec1 = nn.Conv2d(64, in_channels, 3, padding=1) |
| |
| self.inv_e = 1.0 / np.e |
| |
| def _conv_block(self, in_ch: int, out_ch: int) -> nn.Module: |
| return nn.Sequential( |
| nn.Conv2d(in_ch, out_ch, 3, padding=1), |
| nn.GroupNorm(8, out_ch), |
| nn.SiLU() |
| ) |
| |
| def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: |
| """Predict noise given noisy image and timestep.""" |
| |
| t_emb = self.time_mlp(t.float().view(-1, 1) / 1000.0) |
| |
| |
| h1 = self.enc1(x) |
| h2 = self.enc2(F.max_pool2d(h1, 2)) |
| |
| |
| t_emb_spatial = t_emb.view(-1, t_emb.size(1), 1, 1) |
| t_emb_spatial = t_emb_spatial.expand(-1, -1, h2.size(2), h2.size(3)) |
| h2 = torch.cat([h2, t_emb_spatial], dim=1) |
| |
| |
| h = F.interpolate(self.dec2(h2), size=x.shape[2:], mode='nearest') |
| h = h + h1 |
| |
| return self.dec1(h) |
|
|
|
|
| def demo_sentinel_diffusion(): |
| """Demo Sentinel diffusion on synthetic images.""" |
| print("=" * 70) |
| print(" SENTINEL DIFFUSION MODEL") |
| print("=" * 70) |
| |
| |
| schedule = SentinelNoiseSchedule(timesteps=1000, z=2.0) |
| |
| print(f"\n--- Sentinel Noise Schedule ---") |
| print(f" Timesteps: {schedule.timesteps}") |
| print(f" Initial β: {schedule.betas[0].item():.6f}") |
| print(f" Middle β: {schedule.betas[500].item():.6f}") |
| print(f" Final β: {schedule.betas[-1].item():.6f}") |
| print(f" Schedule shape: super-exponential rise") |
| |
| |
| x = torch.randn(4, 3, 32, 32) |
| t = schedule.sample_timesteps(4) |
| |
| |
| noisy_x, noise = schedule.add_noise(x, t) |
| |
| print(f"\n--- Noise Addition ---") |
| print(f" Clean image range: [{x.min():.2f}, {x.max():.2f}]") |
| print(f" Noisy image range: [{noisy_x.min():.2f}, {noisy_x.max():.2f}]") |
| print(f" Noise range: [{noise.min():.2f}, {noise.max():.2f}]") |
| |
| |
| model = SentinelUNet(in_channels=3) |
| pred_noise = model(noisy_x, t) |
| |
| print(f"\n Predicted noise shape: {pred_noise.shape}") |
| print(f" Predicted noise range: [{pred_noise.min():.2f}, {pred_noise.max():.2f}]") |
| |
| print(f"\n ✓ Super-exponential noise schedule for sharp transitions") |
| print(f" ✓ Sentinel-inspired: preserves structure early, destroys late") |
| print(f" ✓ Potential: fewer diffusion steps needed vs Gaussian schedules") |
| |
| print(f"\n{'='*70}") |
| print(f" SENTINEL DIFFUSION: SHARPER TRANSITIONS, FEWER STEPS") |
| print(f"{'='*70}") |
|
|
|
|
| if __name__ == '__main__': |
| demo_sentinel_diffusion() |
|
|