| """ |
| Noising schedule for SAD (Soft-Ancestor Skip Diffusion). |
| |
| We implement a practical discrete-time level-mixture noising process: |
| - Sample t ~ Uniform[eps, 1-eps] |
| - Convert t to categorical level weights rho_0(t), ..., rho_L(t) |
| - For each token independently, sample a corruption level h ~ rho(t) |
| - Build noisy input using the state representation at level h |
| |
| Two schedules are provided: |
| 1. `CategoricalLevelSchedule`: rho_l(t) is a piecewise linear function of t. |
| rho_0(t) = 1-t (stay clean), |
| rho_l(t) proportional to max(0, t - (l-1)/L) - max(0, t - l/L) for l=1..L |
| This is a uniform "smear" across levels as t increases. |
| Easy to implement and works well in practice. |
| |
| 2. `AdjacentLevelSchedule`: heavier mass on adjacent-level transitions. |
| More principled; TODO: connect to exact CTMC in a future version. |
| |
| APPROXIMATION NOTE: |
| The exact CTMC ELBO for a continuous-time multi-level process requires computing |
| per-token transition rates between non-adjacent states, which is mathematically |
| involved for learned soft ancestors. The schedules below are practical discrete-time |
| surrogates. Hooks for a stricter CTMC ELBO are marked with # TODO(CTMC). |
| """ |
|
|
| from abc import ABC, abstractmethod |
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
class LevelNoiseSchedule(nn.Module, ABC):
    """
    Base class for level-mixture noise schedules.

    Level numbering convention:
        level 0                : clean leaf tokens (vocabulary)
        level 1 .. L-1         : intermediate prototype levels (from level_sizes[1..L-1])
        level L (= num_levels) : MASK state (implicit; uses mask_token_id)

    So `num_levels` = len(level_sizes) counts only the real vocabulary levels.
    The mask state lives one step *above* the coarsest prototype level and is
    NOT a prototype; it is represented by mask_token_id directly.

    `total_states` = num_levels + 1 (includes the mask level).
    """

    def __init__(self, num_levels: int, vocab_size: int, mask_token_id: int):
        """
        Args:
            num_levels: number of real vocabulary levels (leaf + intermediate
                prototypes). Mask is the implicit (num_levels)-th level.
            vocab_size: |V|, leaf vocabulary size.
            mask_token_id: id of the [MASK] token.
        """
        super().__init__()
        self.num_levels = num_levels
        self.total_states = num_levels + 1
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id

    @abstractmethod
    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        """
        Return per-level probability weights rho(t).

        Args:
            t: [B] timesteps in [0, 1]

        Returns:
            rho: [B, total_states] where total_states = num_levels + 1.
                rho[:, 0]          = weight for level 0 (clean leaf)
                rho[:, 1..L-1]     = weight for intermediate prototype levels
                rho[:, num_levels] = weight for MASK level
        """
        raise NotImplementedError

    def sample_levels(self, t: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Sample independent corruption levels for each (batch, position).

        Every position of batch row b is drawn i.i.d. from rho(t[b]).

        Args:
            t: [B] timesteps
            seq_len: S, number of positions per batch row

        Returns:
            levels: [B, S] int64, values in [0, total_states-1];
                value == num_levels means MASK.
        """
        rho = self.level_weights(t)
        if seq_len == 0:
            # torch.multinomial refuses to draw zero samples; an empty
            # sequence simply has no levels to sample.
            return torch.empty(
                (rho.shape[0], 0), dtype=torch.long, device=rho.device
            )
        # Drawing S samples per row *with replacement* gives S independent
        # categorical draws from rho[b], without first materialising a
        # [B*S, total_states] copy of the weights.
        return torch.multinomial(rho, seq_len, replacement=True)
|
|
|
|
class CategoricalLevelSchedule(LevelNoiseSchedule):
    """
    Simple piecewise-linear level-mixture schedule.

    States: 0=leaf, 1..num_levels-1=intermediate prototypes, num_levels=MASK
    total_states = num_levels + 1

    State l gets a triangular tent centred at t = l / (total_states - 1)
    with half-width 1 / (total_states - 1), so:

    At t=0: all weight on state 0 (clean leaf).
    At t=1: all weight on state num_levels (MASK).
    In between, each intermediate state peaks at its centre and neighbouring
    tents overlap, so every row of rho sums to 1.

    NOTE(review): unit-sum tents are linear interpolation between the two
    nearest states, which on t in [0, 1] gives the same weights as
    AdjacentLevelSchedule. Confirm whether a wider "smear" across more than
    two levels was intended for this schedule.
    """

    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        """
        Return tent-function level weights rho(t).

        Args:
            t: [B] timesteps; values outside [0, 1] are clamped into range.

        Returns:
            rho: [B, total_states], non-negative, each row summing to 1.
        """
        n_states = self.total_states
        if n_states < 2:
            # Only one state exists, so it carries all the mass.
            return torch.ones(
                t.shape[0], n_states, device=t.device, dtype=t.dtype
            )

        # Keep t inside the span of the tent centres; otherwise a far-out t
        # would fall outside every tent and produce an all-zero row.
        t = t.clamp(0.0, 1.0)

        # One tent per state, centres evenly spaced from 0 (leaf) to 1 (MASK).
        centers = torch.linspace(
            0.0, 1.0, n_states, device=t.device, dtype=t.dtype
        )
        spacing = 1.0 / (n_states - 1)

        # Tent height falls linearly from 1 at the centre to 0 one spacing
        # away. Adjacent tents overlap, so the total mass is positive for
        # every t in [0, 1]; torch.multinomial rejects all-zero rows.
        distance = (t.unsqueeze(1) - centers.unsqueeze(0)).abs()
        rho = (1.0 - distance / spacing).clamp(min=0.0)

        # The tents already sum to 1; renormalise to absorb rounding error.
        return rho / rho.sum(dim=1, keepdim=True).clamp(min=1e-8)
|
|
|
|
class AdjacentLevelSchedule(LevelNoiseSchedule):
    """
    Adjacent-level schedule: at each t, mass splits between two adjacent states.

    States: 0=leaf .. num_levels-1=coarsest prototype, num_levels=MASK
    total_states = num_levels + 1

    mu(t) = t * num_levels maps t in [0,1] to mean state in [0, num_levels].
    The state floor(mu) receives weight 1 - frac(mu) and the next state
    receives frac(mu).

    # TODO(CTMC): Connect to exact continuous-time rates.
    """

    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        """
        Return level weights that interpolate between two adjacent states.

        Args:
            t: [B] timesteps in [0, 1]

        Returns:
            rho: [B, total_states] with at most two non-zero entries per row,
                summing to 1.
        """
        n_states = self.total_states
        if n_states < 2:
            # Only one state exists, so it carries all the mass.
            return torch.ones(
                t.shape[0], n_states, device=t.device, dtype=t.dtype
            )

        mean_state = t * (n_states - 1)
        # Clamp the lower index to N-2 so that t == 1 lands on the pair
        # (N-2, N-1) with all weight on the upper (MASK) state.
        lower = mean_state.floor().long().clamp(0, n_states - 2)
        upper = (lower + 1).clamp(0, n_states - 1)

        # Cast the index to t's own dtype (not a hard-coded float32) so the
        # weights keep rho's dtype; scatter_ requires src and self to match.
        upper_weight = (mean_state - lower.to(t.dtype)).clamp(0.0, 1.0)
        lower_weight = 1.0 - upper_weight

        rho = torch.zeros(t.shape[0], n_states, device=t.device, dtype=t.dtype)
        rho.scatter_(1, lower.unsqueeze(1), lower_weight.unsqueeze(1))
        rho.scatter_add_(1, upper.unsqueeze(1), upper_weight.unsqueeze(1))
        return rho
|
|
|
|
def sample_t(batch_size: int, eps: float = 1e-4,
             low_discrepancy: bool = False,
             device=None) -> torch.Tensor:
    """
    Draw `batch_size` timesteps in [eps, 1-eps].

    Args:
        batch_size: number of timesteps to draw.
        eps: margin kept away from both ends of the unit interval.
        low_discrepancy: if True, use the evenly spaced grid i / batch_size
            shifted by one shared uniform offset and wrapped modulo 1,
            instead of independent uniform draws.
        device: device on which the tensors are created.

    Returns:
        t: [batch_size] float tensor.
    """
    if not low_discrepancy:
        unit = torch.rand(batch_size, device=device)
    else:
        grid = torch.arange(batch_size, device=device).float() / batch_size
        shift = torch.rand(1, device=device)
        unit = (grid + shift).fmod(1.0)
    # Affine map from the unit interval onto [eps, 1 - eps].
    return (1 - 2 * eps) * unit + eps
|
|
|
|
def get_schedule(schedule_type: str, num_levels: int,
                 vocab_size: int, mask_token_id: int) -> LevelNoiseSchedule:
    """
    Build the noise schedule selected by `schedule_type`.

    Args:
        schedule_type: "categorical" for CategoricalLevelSchedule or
            "adjacent" for AdjacentLevelSchedule.
        num_levels: forwarded to the schedule constructor.
        vocab_size: forwarded to the schedule constructor.
        mask_token_id: forwarded to the schedule constructor.

    Returns:
        The constructed schedule instance.

    Raises:
        ValueError: if `schedule_type` is neither "categorical" nor "adjacent".
    """
    # Reject unknown names up front so the happy path is a single choice.
    if schedule_type not in ("categorical", "adjacent"):
        raise ValueError(f"Unknown schedule type: {schedule_type}")

    schedule_cls = (
        CategoricalLevelSchedule
        if schedule_type == "categorical"
        else AdjacentLevelSchedule
    )
    return schedule_cls(num_levels, vocab_size, mask_token_id)
|
|