"""
Centralized noise schedule manager for diffusion models.

Supports three schedules:
1. 'cosine': Cosine schedule (Nichol & Dhariwal, 2021)
2. 'linear_beta': Linear beta schedule (Ho et al., 2020)
3. 'linear_interp': Linear interpolation, the Flow Matching default

All schedules return (alpha_t, sigma_t) such that:
    x_t = alpha_t * x_0 + sigma_t * epsilon

'cosine' and 'linear_beta' are variance preserving (alpha_t^2 + sigma_t^2 = 1);
'linear_interp' is not: it satisfies alpha_t + sigma_t = 1 instead.
"""
|
|
import math
from typing import Tuple

import torch


class NoiseSchedule:
    """
    Centralized noise schedule manager.

    Args:
        schedule_type: One of ['cosine', 'linear_beta', 'linear_interp']
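
    Example (illustrative usage):
        sched = NoiseSchedule('cosine')
        alpha_t, sigma_t = sched.get_schedule(torch.rand(8))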
| """ |
| |
| def __init__(self, schedule_type: str = 'linear_interp'): |
| assert schedule_type in ['cosine', 'linear_beta', 'linear_interp'], \ |
| f"Unknown schedule: {schedule_type}. Must be one of ['cosine', 'linear_beta', 'linear_interp']" |
| self.schedule_type = schedule_type |
| |
| |
| self.beta_min = 0.0001 |
| self.beta_max = 0.02 |
| self.num_timesteps = 1000 |
| |
| |
| self.s = 0.008 |
| |
    def get_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get (alpha_t, sigma_t) for given timesteps.

        Args:
            t: Tensor of timesteps in [0, 1], shape (B,)

        Returns:
            alpha_t: Shape (B,), coefficient for x_0
            sigma_t: Shape (B,), coefficient for epsilon
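
        Example (illustrative, with the default 'linear_interp' schedule):
            sched = NoiseSchedule('linear_interp')
            alpha_t, sigma_t = sched.get_schedule(torch.tensor([0.0, 0.5, 1.0]))
            # alpha_t = [1.0, 0.5, 0.0], sigma_t = [0.0, 0.5, 1.0]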
| """ |
| if self.schedule_type == 'cosine': |
| return self._cosine_schedule(t) |
| elif self.schedule_type == 'linear_beta': |
| return self._linear_beta_schedule(t) |
| elif self.schedule_type == 'linear_interp': |
| return self._linear_interpolation(t) |
| |
    def _cosine_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Cosine schedule: alpha_bar_t = f(t) / f(0),
        where f(t) = cos²(((t + s) / (1 + s)) * π/2)

        Reference: "Improved Denoising Diffusion Probabilistic Models"
        (Nichol & Dhariwal, 2021)

        This schedule is better conditioned than the linear beta schedule,
        especially at very small and very large t.
        """
        # f(t) = cos²(((t + s) / (1 + s)) * π/2)
        f_t = torch.cos(((t + self.s) / (1 + self.s)) * math.pi * 0.5) ** 2

        # Normalizer f(0), so that alpha_bar(0) = 1
        f_0 = math.cos((self.s / (1 + self.s)) * math.pi * 0.5) ** 2

        alpha_bar_t = f_t / f_0

        # Clamp for numerical stability near t = 1
        alpha_bar_t = torch.clamp(alpha_bar_t, min=1e-8, max=1.0)

        # Variance preserving: alpha_t^2 + sigma_t^2 = 1
        alpha_t = torch.sqrt(alpha_bar_t)
        sigma_t = torch.sqrt(1 - alpha_bar_t)

        return alpha_t, sigma_t

    def _linear_beta_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Linear beta schedule: beta_t increases linearly from beta_min to beta_max.

        Reference: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)

        For continuous time t ∈ [0, 1], with N = num_timesteps discrete steps
        and alpha_bar = prod(1 - beta_i) ≈ exp(-N * integral_0^t beta(s) ds):
            beta(t) = beta_min + t * (beta_max - beta_min)
            alpha_bar(t) = exp(-N * t * (beta_min + 0.5 * t * (beta_max - beta_min)))
        """
        # integral_0^t beta(s) ds for the linear schedule
        integral_beta = self.beta_min * t + 0.5 * t * t * (self.beta_max - self.beta_min)
        # Scale by num_timesteps: the discrete betas are per-step quantities
        alpha_bar_t = torch.exp(-integral_beta * self.num_timesteps)

        # Variance preserving: alpha_t^2 + sigma_t^2 = 1
        alpha_t = torch.sqrt(alpha_bar_t)
        sigma_t = torch.sqrt(1 - alpha_bar_t)

        return alpha_t, sigma_t

    def _linear_interpolation(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Linear interpolation: x_t = (1 - t) * x_0 + t * epsilon

        This is the Flow Matching default, but it is NOT variance preserving
        (alpha_t + sigma_t = 1 rather than alpha_t^2 + sigma_t^2 = 1), so it
        is not a proper DDPM schedule. This is what the current implementation uses.
        """
        alpha_t = 1 - t
        sigma_t = t
        return alpha_t, sigma_t

    def get_snr(self, t: torch.Tensor) -> torch.Tensor:
        """
        Compute the signal-to-noise ratio: SNR(t) = alpha_t^2 / sigma_t^2.

        Useful for:
        1. Time warping between different schedules
        2. Analysis and visualization

        Args:
            t: Tensor of timesteps in [0, 1]

        Returns:
            snr: Signal-to-noise ratio at each timestep
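
        Example (illustrative, with the default 'linear_interp' schedule):
            sched = NoiseSchedule('linear_interp')
            snr = sched.get_snr(torch.tensor([0.5]))
            # alpha_t = sigma_t = 0.5, so snr ~ 1.0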
| """ |
| alpha_t, sigma_t = self.get_schedule(t) |
| snr = (alpha_t ** 2) / (sigma_t ** 2 + 1e-8) |
| return snr |
| |
    def alpha_to_time(self, alpha: torch.Tensor, num_steps: int = 100) -> torch.Tensor:
        """
        Inverse mapping: given alpha, find t.

        Used for inference when you want to specify noise levels directly.
        Evaluates the schedule on a uniform grid of t values and returns the
        nearest match; this works because the schedules are monotonic in t.

        Args:
            alpha: Desired alpha values, shape (B,)
            num_steps: Resolution of the search grid

        Returns:
            t: Corresponding timesteps, shape (B,)
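
        Example (illustrative, with the default 'linear_interp' schedule,
        where alpha = 1 - t):
            sched = NoiseSchedule('linear_interp')
            t = sched.alpha_to_time(torch.tensor([0.75]))
            # t ~ 0.25, up to the grid resolution set by num_steps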
| """ |
| device = alpha.device |
| |
| |
| t_candidates = torch.linspace(0, 1, num_steps, device=device) |
| alpha_candidates, _ = self.get_schedule(t_candidates) |
| |
| |
| distances = torch.abs(alpha_candidates.unsqueeze(0) - alpha.unsqueeze(1)) |
| indices = torch.argmin(distances, dim=1) |
| t = t_candidates[indices] |
| |
| return t |
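

# The get_snr docstring mentions time warping between schedules. A minimal
# sketch of that idea is below: map timesteps under one schedule to timesteps
# under another with (approximately) matching SNR, reusing the same grid
# search as alpha_to_time. `warp_time` is a hypothetical helper, not part of
# the original API.
def warp_time(t: torch.Tensor, source: NoiseSchedule, target: NoiseSchedule,
              num_steps: int = 1000) -> torch.Tensor:
    """Find t' such that target.get_snr(t') ~ source.get_snr(t)."""
    device = t.device
    # Match in log-SNR space, since SNR spans many orders of magnitude
    # over t in [0, 1].
    log_snr = torch.log(source.get_snr(t) + 1e-20)

    # Evaluate the target schedule's log-SNR on a uniform grid of t values
    t_candidates = torch.linspace(0, 1, num_steps, device=device)
    log_snr_candidates = torch.log(target.get_snr(t_candidates) + 1e-20)

    # Nearest-neighbor match in log-SNR space
    distances = torch.abs(log_snr_candidates.unsqueeze(0) - log_snr.unsqueeze(1))
    return t_candidates[torch.argmin(distances, dim=1)]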
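

if __name__ == '__main__':
    # Quick sanity check (illustrative): 'cosine' and 'linear_beta' should be
    # variance preserving, while 'linear_interp' satisfies alpha + sigma = 1.
    t = torch.linspace(0.01, 0.99, 5)
    for name in ['cosine', 'linear_beta', 'linear_interp']:
        sched = NoiseSchedule(name)
        alpha_t, sigma_t = sched.get_schedule(t)
        vp = alpha_t ** 2 + sigma_t ** 2
        print(f"{name:13s} alpha^2 + sigma^2 = {[round(v, 4) for v in vp.tolist()]}")

    # Round-trip through the inverse mapping (exact up to grid resolution)
    sched = NoiseSchedule('cosine')
    alpha_t, _ = sched.get_schedule(t)
    t_recovered = sched.alpha_to_time(alpha_t, num_steps=1000)
    print("max |t - t_recovered| =", (t - t_recovered).abs().max().item())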