"""MDLM Rao-Blackwellized Masked Diffusion Loss.

Implements the masked-diffusion ELBO from:
    Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM),
    NeurIPS 2024, arXiv:2406.07524.

Equations referenced:
    - Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t)
    - Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1)
    - RB-ELBO: eq. 7-8
          L_RB = E_t E_q [ (alpha'_t / (1 - alpha_t)) * CE(x_theta(x_t), x_0) ]
      where the expectation is over masked positions. For alpha_t = 1 - t, the
      magnitude is proportional to 1 / t, i.e. inverse mask probability, not
      inverse keep probability.

Key insight: the Rao-Blackwellized estimate replaces an average over all masks
(exponentially many) by a closed-form weighted CE that applies the inverse
mask-probability weight only on the positions that were masked, and 0 on
unmasked positions. This gives an unbiased estimator with lower variance than
a naive Monte Carlo estimate over mask patterns.

Reference implementation cross-checked against:
    https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss)
"""
from __future__ import annotations
from typing import Literal
import torch
import torch.nn.functional as F
# Clamping the weight keeps gradients finite while still up-weighting masked
# positions at low noise levels (small t, where the 1/t weight is largest).
# The historical value 1/eps=1000 blew up HYDRA training on a 12h v2 launch
# (2026-04-22): loss 26 → 42 → NaN in 13 steps under Muon lr=7e-3, because
# per-token CE × 1000 saturated the 100-unit FAIL guard. The MDLM paper
# reports stable training at Adam lr=1e-4; HYDRA uses Muon at 7e-3
# (70× larger), so the weight clamp needs to compensate.
#
# Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0). Set =1.0 to disable
# weighting entirely (flat masked-LM CE, no RB reweighting — simpler and
# more stable, but sacrifices the theoretical ELBO property).
import os as _os
# Upper bound on the per-token RB weight, read once at import time from the
# HYDRA_MDLM_MAX_WEIGHT environment variable (default 5.0).
_MAX_WEIGHT: float = float(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0"))
# Fail fast on a cap that cannot work: 0 would otherwise surface as a bare
# ZeroDivisionError below, and a negative or NaN cap would silently disable the
# clamp (the floor on mask probability would never bind). `not x > 0` is used
# instead of `x <= 0` so that NaN is rejected as well.
if not _MAX_WEIGHT > 0.0:
    raise ValueError(
        f"HYDRA_MDLM_MAX_WEIGHT must be a positive number, got {_MAX_WEIGHT!r}"
    )
# Floor applied to the mask probability, chosen so that
# 1 / clamp(mask_prob, min=_MIN_MASK_PROB) <= _MAX_WEIGHT.
_MIN_MASK_PROB: float = 1.0 / _MAX_WEIGHT
# Back-compat export for older tests/scripts that imported _MIN_ALPHA. The
# minimum now applies to mask probability t = 1 - alpha_t, not alpha_t itself.
_MIN_ALPHA: float = _MIN_MASK_PROB
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def mdlm_masked_forward_process(
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """MDLM forward (noising) process: mask tokens and compute RB weights.

    Args:
        targets: (B, T) integer token ids — the clean sequence x_0. The input
            tensor is not modified.
        mask_token_id: The special token id written into masked positions.
        t: (B,) noise level per batch element, each value in [0, 1]. If None,
            one value per batch element is drawn with ``torch.rand`` (which
            samples from [0, 1), so exactly 0 is possible). t=0 means fully
            clean; t=1 means fully masked.
        alpha_schedule: Noise schedule name. Both accepted names,
            "loglinear" (the default) and "linear", select the same formula
            alpha_t = 1 - t.

    Returns:
        x_t: (B, T), same dtype as ``targets`` — noised sequence; masked
            positions hold ``mask_token_id``, unmasked positions equal
            ``targets``.
        mask_positions: (B, T) bool — True where the token was masked.
        loss_weights: (B, T) float32 — RB weighting factor. On masked
            positions: 1 / max(1 - alpha_t, _MIN_MASK_PROB), i.e. the inverse
            mask probability capped at _MAX_WEIGHT. On unmasked positions:
            0.0. Intended to be multiplied with a per-token cross-entropy.

    Raises:
        ValueError: if ``targets`` is not 2-D, if ``t`` does not have shape
            (B,), if any entry of ``t`` is outside [0, 1] or NaN, or if
            ``alpha_schedule`` is not a recognised name.
    """
    if targets.dim() != 2:
        raise ValueError(
            f"targets must be shape (B, T), got {tuple(targets.shape)}"
        )
    B, T = targets.shape
    device = targets.device
    dtype = torch.float32

    # --- sample or validate t ---
    if t is None:
        t = torch.rand(B, device=device, dtype=dtype)
    else:
        t = t.to(device=device, dtype=dtype)
        if t.shape != (B,):
            raise ValueError(f"t must be shape (B,)={(B,)}, got {t.shape}")
        # Phrased as "every entry is inside the range" so that NaN, for which
        # both `t < 0` and `t > 1` are False, is rejected too.
        if not ((t >= 0) & (t <= 1)).all():
            raise ValueError("t must be in [0, 1]")

    # --- noise schedule: alpha_t = probability that a token is NOT masked ---
    if alpha_schedule not in ("linear", "loglinear"):
        raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.")
    alpha_t = 1.0 - t  # (B,) float32, in [0, 1]

    # --- per-token Bernoulli mask, independent per token and batch element ---
    # `rand` lies in [0, 1). Using `>=` (not `>`) makes the endpoints exact:
    # alpha_t == 0 (t == 1) masks every token even if a draw is exactly 0,
    # and alpha_t == 1 (t == 0) masks none.
    rand = torch.rand(B, T, device=device, dtype=dtype)
    mask_positions = rand >= alpha_t[:, None]  # (B, T) bool, True = masked

    # --- build x_t (out of place; `targets` is left untouched) ---
    x_t = targets.masked_fill(mask_positions, mask_token_id)

    # --- RB loss weights: inverse mask probability on masked positions ---
    # With alpha_t = 1 - t the mask probability is t. Clamping it from below
    # keeps 1 / mask_prob finite (at most _MAX_WEIGHT) as t approaches 0.
    mask_prob = (1.0 - alpha_t).clamp(min=_MIN_MASK_PROB)  # (B,)
    # Broadcast the per-sample weight over T and zero out unmasked positions.
    loss_weights = (1.0 / mask_prob)[:, None] * mask_positions.to(dtype)  # (B, T)

    return x_t, mask_positions, loss_weights
def mdlm_rb_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_positions: torch.Tensor,
    loss_weights: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Rao-Blackwellized negative ELBO.

    Applies the MDLM loss: cross-entropy on masked positions only, weighted
    per-token by ``loss_weights``, averaged per sample and then over the batch.

    With ``sup_i = mask_i and (target_i != ignore_index)``:

        L_RB = mean_B [ sum_T (weight_i * CE(logits_i, target_i) * sup_i)
                        / max(sum_T(sup_i), 1) ]

    Positions whose target equals ``ignore_index`` contribute to neither the
    numerator nor the denominator, so masked padding does not dilute the
    per-sample mean.

    NOTE(review): this normalises by the number of supervised masked tokens
    per sample. The reference MDLM code appears to normalise by the total
    token count instead — confirm the per-masked-token mean is intended.

    Args:
        logits: (B, T, V) raw logits. May be bf16; cast to float32 before the
            cross-entropy is computed.
        targets: (B, T) int64 true token ids (x_0).
        mask_positions: (B, T) mask — True (non-zero) = masked position.
        loss_weights: (B, T) float32 — inverse mask probability on masked
            positions, 0 elsewhere.
        ignore_index: Passed to ``F.cross_entropy``; positions with this label
            are excluded from the loss.

    Returns:
        Scalar float32 loss. Evaluates to 0.0 when no position is both masked
        and supervised.
    """
    B, T, V = logits.shape

    # A position is supervised only if it was masked AND carries a real label.
    # Pre-existing ignore_index labels (e.g. padding or separator masking done
    # by the caller) are honoured.
    supervised = mask_positions.to(torch.bool) & (targets != ignore_index)  # (B, T)

    # Everything that is not supervised gets ignore_index so that CE returns 0
    # there.
    targets_masked = targets.masked_fill(~supervised, ignore_index)

    # Per-token CE in float32; ignored positions come back as 0.
    per_tok_ce = F.cross_entropy(
        logits.float().reshape(B * T, V),
        targets_masked.reshape(B * T),
        ignore_index=ignore_index,
        reduction="none",
    ).reshape(B, T)  # (B, T) float32

    # Apply the RB weight.
    weighted = per_tok_ce * loss_weights  # (B, T)

    # Per-sample mean over supervised positions; the clamp avoids 0/0 for
    # samples with nothing to supervise (their numerator is already 0).
    supervised_count = supervised.sum(dim=1).clamp(min=1)  # (B,)
    per_sample_loss = weighted.sum(dim=1) / supervised_count  # (B,)

    return per_sample_loss.mean()  # scalar float32
def mdlm_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
    ignore_index: int = -100,
) -> torch.Tensor:
    """Run the forward masking process and the RB loss in a single call.

    The mask and the loss weights are sampled inside this function by
    ``mdlm_masked_forward_process``; the corrupted sequence it produces is
    discarded because ``logits`` are supplied by the caller. The sampled mask
    is therefore not derived from whatever input produced ``logits``. Callers
    that need the loss mask to match the model's corrupted input should call
    ``mdlm_masked_forward_process`` and ``mdlm_rb_loss`` separately (see the
    orchestrator wiring note in training.py).

    Args:
        logits: (B, T, V) raw logits over the full vocabulary.
        targets: (B, T) int64 clean token ids.
        mask_token_id: The MASK token id used by the forward process.
        t: Optional (B,) noise level per batch element. Sampled if None.
        alpha_schedule: "loglinear" (default) or "linear".
        ignore_index: Label value to exclude from the loss (e.g. padding).

    Returns:
        Scalar loss tensor as returned by ``mdlm_rb_loss``.

    Note on sampled-softmax / partial logits:
        If the model only computes logits for a subset of the vocabulary
        (e.g. HYDRA's sampled-softmax head), use the two underlying functions
        directly; ``mdlm_rb_loss`` expects full-vocab logits.
    """
    # Only the mask and weights are needed here; the noised ids are dropped.
    _, masked, weights = mdlm_masked_forward_process(
        targets,
        mask_token_id,
        t=t,
        alpha_schedule=alpha_schedule,
    )
    return mdlm_rb_loss(
        logits,
        targets,
        masked,
        weights,
        ignore_index=ignore_index,
    )
|