sad / src /diffusion /noisy_state.py
haochengsama's picture
Add files using upload-large-folder tool
922bb4b verified
Raw
History Blame Contribute Delete
9.4 kB
"""
Noisy state builder for block-AR SAD training.
Noising levels per token (sampled uniformly or from the HDLM schedule;
L = number of ancestor levels):
0 = clean (keep leaf embedding)
1..L = ancestor level l (sample from LUT, use learnable ancestor embedding)
L+1 = mask token
AncestorTable provides:
- Fixed LUT (indices + probs): which ancestor cluster each token maps to
- Learnable ancestor embeddings: the actual embedding used as noisy input
"""
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
from src.diffusion.ancestor_table import AncestorTable
class NoisyStateBuilder(nn.Module):
    """
    Builds noisy embeddings for vectorized block-AR training.

    Level convention shared by every method (L = number of ancestor levels):
        0       clean (leaf embedding kept)
        1..L    ancestor level l
        L + 1   mask

    Args:
        vocab_size: V
        mask_token_id: id of [MASK] in the leaf vocabulary
        temp: kept for API compat, unused (temp is baked into LUT probs)
        top_k_per_level: kept for API compat, unused
        use_soft_expected: accepted for API compat; neither stored nor read
            by this class
    """

    def __init__(
        self,
        vocab_size: int,
        mask_token_id: int,
        temp: float = 1.0,
        top_k_per_level: Optional[List[int]] = None,
        use_soft_expected: bool = True,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.mask_token_id = mask_token_id
        # temp / top_k_per_level / use_soft_expected are kept for API compat
        # but intentionally not stored or used.

    def sample_levels_uniform(
        self, B: int, S: int, num_total_states: int, device: torch.device
    ) -> torch.Tensor:
        """
        Sample per-token levels uniformly from {0, 1, ..., num_total_states-1}.

        To cover clean, every ancestor level and the mask level, pass
        num_total_states = num_ancestor_levels + 2, i.e. mask_level + 1 with
        mask_level = num_ancestor_levels + 1.

        Args:
            B: batch size
            S: sequence length
            num_total_states: exclusive upper bound of the sampled levels
            device: device of the returned tensor

        Returns:
            levels: [B, S] int64
        """
        return torch.randint(0, num_total_states, (B, S), device=device)

    @staticmethod
    def sample_t(
        B: int,
        device: torch.device,
        eps: float = 1e-3,
        low_discrepancy: bool = False,
        rank: int = 0,
        world_size: int = 1,
    ) -> torch.Tensor:
        """
        Sample per-sequence noise level t in [eps, 1-eps] (shape [B]).

        Change 3 (MDLM): low-discrepancy stratified sampling. A single phase u
        is drawn per call; the batch covers the grid {u, u+1/B, ...} (mod 1).
        In DDP, the phase is offset by rank/(world_size*B) so that all ranks
        together cover finer strata (analogous to Kingma's VDM).
        NOTE(review): u comes from the local RNG, so the strata are only
        disjoint across ranks if the ranks draw the same u -- confirm that the
        caller keeps the RNG streams in sync.

        Change 2 (BD3-LMs / Soft-Masked Diffusion): keep t inside
        [eps, 1-eps] to avoid high-variance ELBO gradients at extreme mask
        rates (t->0 or t->1).

        Args:
            B: number of sequences
            device: device of the returned tensor
            eps: margin kept away from 0 and 1
            low_discrepancy: use the stratified grid instead of i.i.d. draws
            rank: DDP rank, only used when world_size > 1
            world_size: number of DDP ranks

        Returns:
            t: [B] float tensor
        """
        if low_discrepancy:
            # Fresh phase every call -- must not be cached across steps.
            u = torch.rand((), device=device)
            if world_size > 1:
                # Offset phase per rank so strata interleave across ranks.
                u = (u + rank / (world_size * B)) % 1.0
            t = torch.arange(B, device=device, dtype=torch.float32) / B
            t = (t + u) % 1.0
        else:
            t = torch.rand(B, device=device)
        # Affine rescale of [0, 1) onto [eps, 1-eps) (not a hard clamp, so the
        # relative spacing of the samples is preserved).
        return (1 - 2 * eps) * t + eps

    def sample_levels_hdlm(
        self,
        t: torch.Tensor,
        S: int,
        num_ancestor_levels: int,
        gamma: float = 1.0,
    ) -> torch.Tensor:
        """
        HDLM 3-state (clean / ancestor / mask) hierarchical schedule,
        γ=1 form from HierarchicalDiffusion.get_alpha_betapi:

            α_t = 1 - t                      # P(clean, keep original)
            c_t = -(1-t) * log(1-t)          # P(ancestor state)
            m_t = t + (1-t) * log(1-t)       # P(mask)

        The three probabilities sum to 1 (with tiny numerical renormalization).
        Each sequence has its own t (shape [B]); each token within the sequence
        is drawn i.i.d. from the resulting categorical. c_t is split evenly
        among ancestor levels 1..L.

        Args:
            t: [B] in (0, 1), from sample_t()
            S: sequence length
            num_ancestor_levels: L = ancestor_table.num_levels (typically 1)
            gamma: only γ=1 supported

        Returns:
            levels: [B, S] int64, values in {0, 1..L, L+1=mask_level}

        Raises:
            NotImplementedError: if gamma != 1.
            ValueError: if num_ancestor_levels < 1.
        """
        if gamma != 1.0:
            raise NotImplementedError("sample_levels_hdlm only supports γ=1")
        # Explicit raise instead of `assert`: asserts vanish under `python -O`.
        if num_ancestor_levels < 1:
            raise ValueError(
                f"num_ancestor_levels must be >= 1, got {num_ancestor_levels}"
            )

        B = t.shape[0]
        if B == 0 or S == 0:
            # Nothing to sample; torch.multinomial rejects num_samples == 0.
            return torch.empty((B, S), dtype=torch.long, device=t.device)

        eps = 1e-8
        one_m_t = (1.0 - t).clamp(min=eps)               # [B]
        log_1m_t = one_m_t.log()                         # [B], <= 0

        alpha_t = one_m_t                                # P(clean)
        c_t = -one_m_t * log_1m_t                        # P(any ancestor) >= 0
        m_t = (t + one_m_t * log_1m_t).clamp(min=0.0)    # P(mask) >= 0

        # Renormalize for safety (should already sum to ~1).
        total = alpha_t + c_t + m_t
        alpha_t = alpha_t / total
        c_t = c_t / total
        m_t = m_t / total

        # Split c_t evenly among ancestor levels 1..L.
        c_per_level = c_t / num_ancestor_levels          # [B]

        # Per-sequence distribution over [clean, anc_1..anc_L, mask].
        # Category index i maps 1-to-1 to the level value:
        #   0    -> clean
        #   1..L -> ancestor levels 1..L
        #   L+1  -> mask_level
        probs = torch.stack(
            [alpha_t] + [c_per_level] * num_ancestor_levels + [m_t],
            dim=-1,
        )                                                # [B, 2+L]

        # S i.i.d. draws per row. Sampling with replacement yields the same
        # distribution as expanding probs to [B*S, 2+L] and drawing one sample
        # per row, without materializing that expanded copy.
        return torch.multinomial(probs, num_samples=S, replacement=True)

    def build_noisy_embeddings(
        self,
        input_ids: torch.Tensor,
        levels: torch.Tensor,
        ancestor_table: "AncestorTable",
        leaf_embeddings: torch.Tensor,
        mask_embedding: torch.Tensor,
        # kept for API compat (ignored):
        hierarchy=None,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[Optional[torch.Tensor]], torch.Tensor]:
        """
        Build noisy embeddings from LUT-based ancestor table.

        For each position:
            level 0:     leaf_embeddings[token_id] (clean)
            level 1..L:  learnable ancestor embedding sampled via LUT
            level L+1:   mask_embedding (fully masked)
        Positions whose level is outside 0..L+1 match none of the cases and
        therefore keep their clean embedding.

        Args:
            input_ids: [B, S] int64
            levels: [B, S] int64 (0 .. num_ancestor_levels+1)
            ancestor_table: AncestorTable
            leaf_embeddings: [V, d]
            mask_embedding: [d]
            hierarchy: ignored (kept for API compat)

        Returns:
            noisy_embs: [B, S, d]
            ancestor_log_probs: [B, S] log of the LUT prob of the chosen
                ancestor (0 at clean/mask), in leaf_embeddings' dtype
            ancestor_probs_per_lvl: list of None, one entry per ancestor level
                (not needed downstream)
            corrupt_mask: [B, S] bool -- True at ancestor and mask positions
        """
        B, S = input_ids.shape
        dtype = leaf_embeddings.dtype
        device = input_ids.device

        num_levels = ancestor_table.num_levels
        mask_level = num_levels + 1

        # Start with clean embeddings everywhere, then overwrite per level.
        noisy_embs = leaf_embeddings[input_ids].clone()          # [B, S, d]
        ancestor_log_probs = torch.zeros(B, S, device=device, dtype=dtype)
        corrupt_mask = torch.zeros(B, S, dtype=torch.bool, device=device)

        # Fully masked positions.
        mask_pos = levels == mask_level
        if mask_pos.any():
            noisy_embs[mask_pos] = mask_embedding.to(dtype)
            corrupt_mask[mask_pos] = True

        for level in range(1, mask_level):    # ancestor levels 1..L
            pos_l = levels == level
            if not pos_l.any():
                # Also avoids calling multinomial on an empty batch.
                continue

            flat_ids = input_ids[pos_l]                                   # [N]

            # Sample one LUT slot per token, then map slot -> cluster id.
            lut_idx = ancestor_table.lut_indices(level)[flat_ids]         # [N, top_k]
            lut_prob = ancestor_table.lut_probs(level)[flat_ids]          # [N, top_k]
            sampled_local = torch.multinomial(lut_prob, num_samples=1)    # [N, 1]
            sampled_global = lut_idx.gather(1, sampled_local).squeeze(1)  # [N]

            # Learnable ancestor embedding.
            anc_emb = ancestor_table.ancestor_embeddings(level)           # [K, d]
            noisy_embs[pos_l] = anc_emb[sampled_global].to(dtype)

            # Log-prob of the chosen ancestor, read from the raw LUT probs.
            chosen_lp = (
                lut_prob.gather(1, sampled_local).squeeze(1).clamp(min=1e-8).log()
            )
            ancestor_log_probs[pos_l] = chosen_lp.to(dtype)

            corrupt_mask[pos_l] = True

        # Per-level probabilities are not needed downstream: one None per level.
        ancestor_probs_per_lvl: List[Optional[torch.Tensor]] = [None] * num_levels

        return noisy_embs, ancestor_log_probs, ancestor_probs_per_lvl, corrupt_mask