sad / src /diffusion /__init__.py
haochengsama's picture
add missing files batch 13 (400)
278b5e7 verified
Raw
History Blame Contribute Delete
7.7 kB
"""
Noising schedule for SAD (Soft-Ancestor Skip Diffusion).
We implement a practical discrete-time level-mixture noising process:
- Sample t ~ Uniform[eps, 1-eps]
- Convert t to categorical level weights rho_0(t), ..., rho_L(t)
- For each token independently, sample a corruption level h ~ rho(t)
- Build noisy input using the state representation at level h
Two schedules are provided:
1. `CategoricalLevelSchedule`: rho_l(t) is a piecewise linear function of t.
   All mass sits on level 0 (clean) at t = 0 and is shifted towards the
   coarser levels, and finally MASK, as t increases. The exact weights are
   defined by `CategoricalLevelSchedule.level_weights`.
   Easy to implement and works well in practice.
2. `AdjacentLevelSchedule`: heavier mass on adjacent-level transitions.
More principled; TODO: connect to exact CTMC in a future version.
APPROXIMATION NOTE:
The exact CTMC ELBO for a continuous-time multi-level process requires computing
per-token transition rates between non-adjacent states, which is mathematically
involved for learned soft ancestors. The schedules below are practical discrete-time
surrogates. Hooks for a stricter CTMC ELBO are marked with # TODO(CTMC).
"""
from abc import ABC, abstractmethod
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class LevelNoiseSchedule(nn.Module, ABC):
    """
    Base class for level-mixture noise schedules.

    Level numbering convention:
        level 0                : clean leaf tokens (vocabulary)
        level 1 .. L-1         : intermediate prototype levels (from level_sizes[1..L-1])
        level L (= num_levels) : MASK state (implicit; uses mask_token_id)

    So `num_levels` = len(level_sizes) counts only the real vocabulary levels.
    The mask state lives one step *above* the coarsest prototype level and is
    NOT a prototype -- it is represented by mask_token_id directly.
    `total_states` = num_levels + 1 (includes the mask level).

    Subclasses implement `level_weights`; `sample_levels` turns those weights
    into per-token corruption levels.
    """

    def __init__(self, num_levels: int, vocab_size: int, mask_token_id: int):
        """
        Args:
            num_levels: number of real vocabulary levels (leaf + intermediate
                prototypes). level_sizes has length num_levels.
                Mask is the implicit (num_levels)-th level.
            vocab_size: |V|, leaf vocabulary size.
            mask_token_id: id of the [MASK] token.
        """
        super().__init__()
        self.num_levels = num_levels          # real vocab levels (NOT counting mask)
        self.total_states = num_levels + 1    # real levels + 1 mask level
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id

    @abstractmethod
    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        """
        Return per-level probability weights rho(t).

        Args:
            t: [B] timesteps in [0, 1]
        Returns:
            rho: [B, total_states] where total_states = num_levels + 1.
                rho[:, 0]          = weight for level 0 (clean leaf)
                rho[:, 1..L-1]     = weight for intermediate prototype levels
                rho[:, num_levels] = weight for MASK level
        """
        raise NotImplementedError

    def sample_levels(self, t: torch.Tensor, seq_len: int) -> torch.Tensor:
        """
        Sample independent corruption levels for each (batch, position).

        Every position of batch row b is drawn independently from the
        categorical distribution rho[b] returned by `level_weights`.

        Args:
            t: [B]
            seq_len: S
        Returns:
            levels: [B, S] int64, values in [0, total_states-1];
                value == num_levels means MASK.
                An empty [B, S] tensor is returned when B == 0 or S == 0.
        """
        rho = self.level_weights(t)  # [B, total_states]
        batch_size = rho.shape[0]

        # torch.multinomial rejects an empty probability matrix, so the
        # degenerate "nothing to sample" case is answered directly.
        if batch_size == 0 or seq_len == 0:
            return torch.empty(
                (batch_size, seq_len), dtype=torch.long, device=rho.device
            )

        # Repeat each row S times so that one draw per row yields one level
        # per (batch, position) pair.
        rho_expanded = rho.unsqueeze(1).expand(-1, seq_len, -1)  # [B, S, total_states]
        flat_levels = torch.multinomial(
            rho_expanded.reshape(-1, self.total_states), 1
        )  # [B*S, 1]
        return flat_levels.reshape(batch_size, seq_len)  # [B, S]
class CategoricalLevelSchedule(LevelNoiseSchedule):
    """
    Simple piecewise-linear level-mixture schedule.

    States: 0=leaf, 1..num_levels-1=intermediate prototypes, num_levels=MASK
    total_states = num_levels + 1

    State l owns a triangular tent that peaks at t = l / num_levels and has
    half-width 1 / num_levels, so neighbouring tents overlap:
        At t=0: all weight on state 0 (clean leaf).
        At t=1: all weight on state num_levels (MASK).
        In between: the two states whose peaks bracket t share the mass
        linearly (uniformly spaced tents form a partition of unity).
    """

    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the tent weights rho(t).

        Args:
            t: [B] timesteps; values outside [0, 1] are clamped into range.
        Returns:
            rho: [B, total_states], non-negative, every row sums to 1.
        """
        N = self.total_states  # num_levels + 1
        device = t.device
        dtype = t.dtype

        # Degenerate schedule with a single state: all mass on it.
        if N == 1:
            return torch.ones(t.shape[0], 1, device=device, dtype=dtype)

        # Peaks sit on N evenly spaced knots (N - 1 intervals). Confining each
        # tent to its own band instead leaves ranges of t where every weight
        # is zero, and torch.multinomial rejects all-zero rows.
        peaks = torch.linspace(0.0, 1.0, N, device=device, dtype=dtype)  # [N]

        # Distance of every t to every peak, measured in knot spacings.
        t_clamped = t.clamp(0.0, 1.0)
        distance = (t_clamped.unsqueeze(1) - peaks.unsqueeze(0)).abs() * (N - 1)

        # Tent: 1 at the peak, falling linearly to 0 one knot away.
        rho = (1.0 - distance).clamp(min=0.0)  # [B, N]

        # Rows already sum to 1 up to rounding; normalise to remove the drift.
        rho = rho / rho.sum(dim=1, keepdim=True).clamp(min=1e-8)
        return rho  # [B, total_states]
class AdjacentLevelSchedule(LevelNoiseSchedule):
    """
    Adjacent-level schedule: at each t, mass splits between two adjacent states.

    States: 0=leaf .. num_levels-1=coarsest prototype, num_levels=MASK
    total_states = num_levels + 1
    mu(t) = t * num_levels maps t in [0,1] to mean state in [0, num_levels].

    # TODO(CTMC): Connect to exact continuous-time rates.
    """

    def level_weights(self, t: torch.Tensor) -> torch.Tensor:
        """
        Split the mass at each t between floor(mu(t)) and floor(mu(t)) + 1.

        Args:
            t: [B] timesteps in [0, 1]
        Returns:
            rho: [B, total_states]; at most two adjacent entries per row are
                non-zero and each row sums to 1.
        """
        B = t.shape[0]
        N = self.total_states  # num_levels + 1
        device = t.device
        dtype = t.dtype

        # Degenerate schedule with a single state: there is no adjacent pair
        # to interpolate between, so all mass stays on that state.
        if N == 1:
            return torch.ones(B, 1, device=device, dtype=dtype)

        mu = t * (N - 1)  # [B] in [0, N-1]

        # Lower state of the bracketing pair; capped at N - 2 so that the
        # upper state lo + 1 is always a valid index (t = 1 -> pair (N-2, N-1)).
        lo = mu.floor().long().clamp(0, N - 2)
        hi = lo + 1

        # Cast with t's own dtype: a hard-coded float32 cast would promote the
        # fractions for half/bfloat16 inputs and scatter_ requires src and
        # self to share a dtype.
        frac_hi = (mu - lo.to(dtype)).clamp(0.0, 1.0)
        frac_lo = 1.0 - frac_hi

        rho = torch.zeros(B, N, device=device, dtype=dtype)
        rho.scatter_(1, lo.unsqueeze(1), frac_lo.unsqueeze(1))
        rho.scatter_add_(1, hi.unsqueeze(1), frac_hi.unsqueeze(1))
        return rho  # [B, total_states]
def sample_t(batch_size: int, eps: float = 1e-4,
             low_discrepancy: bool = False,
             device=None) -> torch.Tensor:
    """
    Draw `batch_size` timesteps from Uniform[eps, 1-eps].

    Args:
        batch_size: number of timesteps to draw.
        eps: margin kept away from both ends of the unit interval.
        low_discrepancy: if True, use an evenly spaced grid i / batch_size
            shifted by one shared random offset (wrapped modulo 1) instead of
            independent uniform draws.
        device: device on which the tensors are created.
    Returns:
        t: [batch_size] float tensor.
    """
    if not low_discrepancy:
        unit = torch.rand(batch_size, device=device)
    else:
        grid = torch.arange(batch_size, device=device).float() / batch_size
        offset = torch.rand(1, device=device)
        # Shift the whole grid by the same offset and wrap back into [0, 1).
        unit = torch.fmod(grid + offset, 1.0)

    # Affine map of the unit interval onto [eps, 1 - eps].
    return (1 - 2 * eps) * unit + eps
def get_schedule(schedule_type: str, num_levels: int,
                 vocab_size: int, mask_token_id: int) -> LevelNoiseSchedule:
    """
    Build the noise schedule named by `schedule_type`.

    Args:
        schedule_type: "categorical" or "adjacent".
        num_levels: forwarded to the schedule constructor.
        vocab_size: forwarded to the schedule constructor.
        mask_token_id: forwarded to the schedule constructor.
    Returns:
        A `CategoricalLevelSchedule` or `AdjacentLevelSchedule` instance.
    Raises:
        ValueError: if `schedule_type` is not one of the known names.
    """
    known_schedules = (
        ("categorical", CategoricalLevelSchedule),
        ("adjacent", AdjacentLevelSchedule),
    )
    for name, schedule_cls in known_schedules:
        if schedule_type == name:
            return schedule_cls(num_levels, vocab_size, mask_token_id)
    raise ValueError(f"Unknown schedule type: {schedule_type}")