# baguette/src/schedules.py
# Initial upload: Paris MoE inference code and weights (commit 4dec1ca, verified)
# src/schedules.py
"""
Centralized noise schedule manager for diffusion models.
Supports three schedules:
1. 'cosine': Cosine schedule (Nichol & Dhariwal 2021)
2. 'linear_beta': Linear beta schedule (Ho et al. 2020)
3. 'linear_interp': Linear interpolation - Flow Matching default
All schedules return (alpha_t, sigma_t) such that:
    x_t = alpha_t * x_0 + sigma_t * epsilon
The 'cosine' and 'linear_beta' schedules are variance preserving
(alpha_t^2 + sigma_t^2 = 1); 'linear_interp' is not, since
(1-t)^2 + t^2 < 1 for t in (0, 1).
"""
import torch
import math
from typing import Tuple
class NoiseSchedule:
    """
    Centralized noise schedule manager.

    Maps continuous time t in [0, 1] to coefficients (alpha_t, sigma_t) for
        x_t = alpha_t * x_0 + sigma_t * epsilon

    Args:
        schedule_type: One of ['cosine', 'linear_beta', 'linear_interp']
    """
    def __init__(self, schedule_type: str = 'linear_interp'):
        assert schedule_type in ['cosine', 'linear_beta', 'linear_interp'], \
            f"Unknown schedule: {schedule_type}. Must be one of ['cosine', 'linear_beta', 'linear_interp']"
        self.schedule_type = schedule_type
        # Linear beta schedule parameters (if used)
        self.beta_min = 0.0001
        self.beta_max = 0.02
        self.num_timesteps = 1000  # T in discrete formulation
        # Cosine schedule parameter
        self.s = 0.008  # Small offset to prevent beta from being too small near t=0

    def get_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get (alpha_t, sigma_t) for given timesteps.

        Args:
            t: Tensor of timesteps in [0, 1], shape (B,)

        Returns:
            alpha_t: Shape (B,), coefficient for x_0
            sigma_t: Shape (B,), coefficient for epsilon
        """
        if self.schedule_type == 'cosine':
            return self._cosine_schedule(t)
        elif self.schedule_type == 'linear_beta':
            return self._linear_beta_schedule(t)
        elif self.schedule_type == 'linear_interp':
            return self._linear_interpolation(t)
        # Unreachable: schedule_type is validated in __init__.
        raise ValueError(f"Unknown schedule: {self.schedule_type}")

    def _cosine_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Cosine schedule: alpha_bar_t = f(t) / f(0)
        where f(t) = cos²((t + s)/(1 + s) * π/2)

        Reference: "Improved Denoising Diffusion Probabilistic Models"
        (Nichol & Dhariwal, 2021)

        This schedule provides better conditioning than linear beta schedule,
        especially at very small and very large t values.
        """
        # f(t) = cos²((t + s)/(1 + s) * π/2)
        f_t = torch.cos(((t + self.s) / (1 + self.s)) * math.pi * 0.5) ** 2
        # f(0) normalizes so that alpha_bar_0 = 1 exactly (no noise at t=0)
        f_0 = math.cos((self.s / (1 + self.s)) * math.pi * 0.5) ** 2
        # Normalize: alpha_bar_t = f(t) / f(0)
        alpha_bar_t = f_t / f_0
        # Clamp to keep sqrt well-defined and avoid exact-zero alpha at t=1
        alpha_bar_t = torch.clamp(alpha_bar_t, min=1e-8, max=1.0)
        # Variance-preserving coefficients: alpha^2 + sigma^2 = 1
        alpha_t = torch.sqrt(alpha_bar_t)
        sigma_t = torch.sqrt(1 - alpha_bar_t)
        return alpha_t, sigma_t

    def _linear_beta_schedule(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Linear beta schedule: beta_t increases linearly from beta_min to beta_max

        Reference: "Denoising Diffusion Probabilistic Models" (Ho et al., 2020)

        For continuous time t ∈ [0, 1] with T = num_timesteps discrete steps:
            beta(t) = beta_min + t * (beta_max - beta_min)       (per-step beta)
            alpha_bar(t) = prod_i (1 - beta_i)
                         ≈ exp(-T * integral_0^t beta(s) ds)
        using log(1 - beta) ≈ -beta for small beta. Then alpha_t = sqrt(alpha_bar_t).
        """
        # integral_0^t (beta_min + s * (beta_max - beta_min)) ds
        #   = beta_min * t + 0.5 * t^2 * (beta_max - beta_min)
        integral_beta = self.beta_min * t + 0.5 * t * t * (self.beta_max - self.beta_min)
        # NOTE: no 0.5 factor here. alpha_bar ≈ exp(-sum beta_i) = exp(-T * ∫beta);
        # the 0.5 enters only once, via the sqrt below (alpha_t = exp(-0.5*T*∫beta)).
        # Applying both a 0.5 factor AND the sqrt would leave terminal SNR far too
        # large (alpha_1 ≈ 0.081 instead of ≈ 0.0066 for the standard DDPM params).
        alpha_bar_t = torch.exp(-integral_beta * self.num_timesteps)
        # Variance-preserving coefficients: alpha^2 + sigma^2 = 1
        alpha_t = torch.sqrt(alpha_bar_t)
        sigma_t = torch.sqrt(1 - alpha_bar_t)
        return alpha_t, sigma_t

    def _linear_interpolation(self, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Linear interpolation: x_t = (1-t) * x_0 + t * epsilon

        This is the default for Flow Matching but NOT a proper DDPM schedule
        (it is not variance preserving: (1-t)^2 + t^2 < 1 for t in (0, 1)).
        """
        alpha_t = 1 - t
        sigma_t = t
        return alpha_t, sigma_t

    def get_snr(self, t: torch.Tensor) -> torch.Tensor:
        """
        Compute signal-to-noise ratio (SNR) = alpha_t^2 / sigma_t^2

        Useful for:
        1. Time warping between different schedules
        2. Analysis and visualization

        Args:
            t: Tensor of timesteps in [0, 1]

        Returns:
            snr: Signal-to-noise ratio at each timestep
        """
        alpha_t, sigma_t = self.get_schedule(t)
        # 1e-8 guards against division by zero at sigma_t = 0 (t = 0)
        snr = (alpha_t ** 2) / (sigma_t ** 2 + 1e-8)
        return snr

    def alpha_to_time(self, alpha: torch.Tensor, num_steps: int = 100) -> torch.Tensor:
        """
        Inverse mapping: given alpha, find t

        Used for inference when you want to specify noise levels directly.
        Performs a nearest-neighbor search over a uniform grid of candidate
        timesteps (valid because each schedule is monotonic in t).

        Args:
            alpha: Desired alpha values, shape (B,)
            num_steps: Resolution of the candidate-time grid; the returned t
                is accurate to roughly 1 / (num_steps - 1)

        Returns:
            t: Timesteps whose alpha is closest to the requested values, shape (B,)
        """
        device = alpha.device
        # Uniform grid of candidate timesteps in [0, 1]
        t_candidates = torch.linspace(0, 1, num_steps, device=device)
        alpha_candidates, _ = self.get_schedule(t_candidates)
        # Pairwise |alpha_candidate - alpha_target|, shape (B, num_steps)
        distances = torch.abs(alpha_candidates.unsqueeze(0) - alpha.unsqueeze(1))
        # Pick, per target, the grid point with the closest alpha
        indices = torch.argmin(distances, dim=1)
        t = t_candidates[indices]
        return t