"""
Noising schedule for SAD (Soft-Ancestor Skip Diffusion).

We implement a practical discrete-time level-mixture noising process:
  - Sample t ~ Uniform[eps, 1-eps]
  - Convert t to categorical level weights rho_0(t), ..., rho_L(t)
  - For each token independently, sample a corruption level h ~ rho(t)
  - Build noisy input using the state representation at level h

Here L = num_levels is the index of the MASK state (see LevelNoiseSchedule).

Two schedules are provided:

1. `CategoricalLevelSchedule`: each rho_l(t) is a piecewise-linear triangular
   "tent" that peaks at t = l / L:
       rho_l(t) = max(0, 1 - |L * t - l|),   l = 0..L
   so rho_0(0) = 1 (fully clean), rho_L(1) = 1 (fully masked), and every
   intermediate level peaks in between. The tents form a partition of unity,
   so the weights sum to 1 for every t in [0, 1] and no row is ever all-zero.

2. `AdjacentLevelSchedule`: at each t the mass is split linearly between the
   two adjacent states floor(L * t) and floor(L * t) + 1 (heavier mass on
   adjacent-level transitions). More principled; TODO: connect to exact CTMC
   in a future version.

   Note: with evenly spaced peaks, the tents of (1) produce the same weights
   as this adjacent split; both names are kept as `get_schedule` options.

APPROXIMATION NOTE: The exact CTMC ELBO for a continuous-time multi-level
process requires computing per-token transition rates between non-adjacent
states, which is mathematically involved for learned soft ancestors. The
schedules below are practical discrete-time surrogates. Hooks for a stricter
CTMC ELBO are marked with # TODO(CTMC).
"""

from abc import ABC, abstractmethod
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class LevelNoiseSchedule(nn.Module, ABC):
    """
    Base class for level-mixture noise schedules.

    Level numbering convention:
        level 0                 : clean leaf tokens (vocabulary)
        level 1 .. L-1          : intermediate prototype levels (from level_sizes[1..L-1])
        level L (= num_levels)  : MASK state (implicit; uses mask_token_id)

    So `num_levels` = len(level_sizes) counts only the real vocabulary levels.
    The mask state lives one step *above* the coarsest prototype level and is
    NOT a prototype — it is represented by mask_token_id directly.
    `total_states` = num_levels + 1 (includes the mask level).
    """

    def __init__(self, num_levels: int, vocab_size: int, mask_token_id: int):
        """
        Args:
            num_levels: number of real vocabulary levels (leaf + intermediate
                prototypes). level_sizes has length num_levels. Mask is the
                implicit (num_levels)-th level.
            vocab_size: |V|, leaf vocabulary size.
            mask_token_id: id of the [MASK] token.

        Raises:
            ValueError: if num_levels < 1 (at least the leaf level must exist,
                otherwise there is no clean state to interpolate from).
        """
        super().__init__()
        if num_levels < 1:
            raise ValueError(f"num_levels must be >= 1, got {num_levels}")
        self.num_levels = num_levels          # number of real vocab levels (NOT counting mask)
        self.total_states = num_levels + 1    # num_levels real levels + 1 mask level
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id

    @abstractmethod
    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        """
        Return per-level probability weights rho(t).

        Args:
            t: [B] timesteps in [0, 1]
        Returns:
            rho: [B, total_states] where total_states = num_levels + 1.
                 rho[:, 0]          = weight for level 0 (clean leaf)
                 rho[:, 1..L-1]     = weight for intermediate prototype levels
                 rho[:, num_levels] = weight for MASK level
                 Each row must be non-negative with a positive sum, since
                 `sample_levels` feeds it to `torch.multinomial`.
        """
        raise NotImplementedError

    def sample_levels(self, t: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Sample independent corruption levels for each (batch, position).

        Returns levels in [0, total_states-1]. Level num_levels means MASK.

        Args:
            t: [B]
            seq_len: S
        Returns:
            levels: [B, S] int64, values in [0, total_states-1]
                    value == num_levels means MASK
        """
        batch_size = t.shape[0]
        if batch_size == 0 or seq_len == 0:
            # Nothing to sample; avoid multinomial/reshape on empty tensors.
            return torch.zeros(batch_size, seq_len, dtype=torch.long, device=t.device)

        rho = self.level_weights(t)                                # [B, total_states]
        rho_expanded = rho.unsqueeze(1).expand(-1, seq_len, -1)   # [B, S, total_states]
        levels = torch.multinomial(
            rho_expanded.reshape(-1, self.total_states), 1
        )                                                          # [B*S, 1]
        # Explicit shape (not -1) so the result is well-defined for any B, S.
        return levels.reshape(batch_size, seq_len)                 # [B, S]


class CategoricalLevelSchedule(LevelNoiseSchedule):
    """
    Piecewise-linear level-mixture schedule built from triangular tents.

    States: 0=leaf, 1..num_levels-1=intermediate prototypes, num_levels=MASK
    total_states = num_levels + 1

    With L = num_levels, state l has the tent
        rho_l(t) = max(0, 1 - |L * t - l|)
    peaking at t = l / L with half-width 1 / L.

    At t=0: all weight on state 0 (clean leaf).
    At t=1: all weight on state num_levels (MASK).
    Intermediate states peak in between. Neighbouring tents overlap exactly,
    so the weights sum to 1 for every t in [0, 1]. t outside [0, 1] is
    clamped to the nearest endpoint.
    """

    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        # t: [B]
        num_states = self.total_states           # num_levels + 1
        last = num_states - 1                    # index of MASK == num_levels (>= 1)

        # Position of t on the integer state axis [0, last]. Working on integer
        # positions keeps the endpoints exact (e.g. rho_{L-1}(1) == 0 exactly).
        pos = t.clamp(0.0, 1.0) * last                                   # [B]
        states = torch.arange(num_states, device=t.device, dtype=t.dtype)  # [N]

        # Triangular tent of half-width 1 around each integer state position.
        rho = (1.0 - (pos.unsqueeze(1) - states.unsqueeze(0)).abs()).clamp(min=0.0)

        # The tents already sum to 1; renormalize only to absorb round-off.
        rho = rho / rho.sum(dim=1, keepdim=True).clamp(min=1e-8)
        return rho  # [B, total_states]


class AdjacentLevelSchedule(LevelNoiseSchedule):
    """
    Adjacent-level schedule: at each t, mass splits between two adjacent states.

    States: 0=leaf .. num_levels-1=coarsest prototype, num_levels=MASK
    total_states = num_levels + 1

    mu(t) = t * num_levels maps t in [0,1] to mean state in [0, num_levels].
    The mass is split linearly between floor(mu) and floor(mu) + 1; t outside
    [0, 1] saturates at state 0 or at MASK.

    # TODO(CTMC): Connect to exact continuous-time rates.
    """

    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        B = t.shape[0]
        N = self.total_states  # num_levels + 1
        device = t.device
        dtype = t.dtype

        mu = t * (N - 1)                             # [B] in [0, N-1]
        # Clamp lo to N-2 so hi = lo + 1 is always a valid, distinct state.
        lo = mu.floor().long().clamp(0, N - 2)
        hi = (lo + 1).clamp(0, N - 1)
        # Cast lo back to t's dtype (not float32) so the scatter sources below
        # match rho's dtype for half/bfloat16 inputs too.
        frac_hi = (mu - lo.to(dtype)).clamp(0.0, 1.0)
        frac_lo = 1.0 - frac_hi

        rho = torch.zeros(B, N, device=device, dtype=dtype)
        rho.scatter_(1, lo.unsqueeze(1), frac_lo.unsqueeze(1))
        rho.scatter_add_(1, hi.unsqueeze(1), frac_hi.unsqueeze(1))
        return rho  # [B, total_states]


def sample_t(batch_size: int, eps: float = 1e-4, low_discrepancy: bool = False,
             device=None) -> torch.Tensor:
    """
    Sample timesteps t ~ Uniform[eps, 1-eps].

    Args:
        batch_size: number of timesteps B.
        eps: margin kept away from 0 and 1.
        low_discrepancy: if True, use a randomly shifted evenly spaced grid
            (one shared uniform offset) instead of i.i.d. draws.
        device: torch device for the result.
    Returns:
        t: [B] float tensor in [eps, 1-eps].
    """
    if low_discrepancy:
        t = torch.arange(batch_size, device=device).float() / batch_size
        t = (t + torch.rand(1, device=device)).fmod(1.0)
    else:
        t = torch.rand(batch_size, device=device)
    t = (1 - 2 * eps) * t + eps
    return t


def get_schedule(schedule_type: str, num_levels: int, vocab_size: int,
                 mask_token_id: int) -> LevelNoiseSchedule:
    """
    Build a level-mixture schedule by name.

    Args:
        schedule_type: "categorical" or "adjacent".
        num_levels, vocab_size, mask_token_id: forwarded to the schedule.
    Raises:
        ValueError: for an unknown schedule_type.
    """
    if schedule_type == "categorical":
        return CategoricalLevelSchedule(num_levels, vocab_size, mask_token_id)
    elif schedule_type == "adjacent":
        return AdjacentLevelSchedule(num_levels, vocab_size, mask_token_id)
    else:
        raise ValueError(f"Unknown schedule type: {schedule_type}")