# icarus112's picture
# Upload folder using huggingface_hub
# c383594 verified
"""MDLM Rao-Blackwellized Masked Diffusion Loss.
Implements the masked-diffusion ELBO from:
Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM),
NeurIPS 2024, arXiv:2406.07524.
Equations referenced:
- Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t)
- Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1)
- RB-ELBO: eq. 7-8 L_RB = E_t E_q [ (-alpha'_t / (1 - alpha_t)) *
CE(x_theta(x_t), x_0) ] where the CE is summed
over masked positions only. For alpha_t = 1 - t, the
weight is 1 / t, i.e. the inverse mask probability,
not the inverse keep probability.
Key insight: the Rao-Blackwellized estimate replaces an average over all masks
(exponentially many) with a closed-form weighted CE that applies the inverse
mask-probability weight only on the positions that were masked, and 0 on
unmasked positions. This gives an unbiased estimator with lower variance than a
naive Monte Carlo over mask patterns. (This module additionally clamps the
weight at _MAX_WEIGHT for training stability, which biases the estimate for
small t.)
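Reference implementation cross-checked against:
https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss)
Note: the reference divides the summed weighted NLL by the total token count;
mdlm_rb_loss here normalizes by the number of masked tokens in each sample,
which scales the loss by roughly an extra factor of 1/t relative to the
reference.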
"""
from __future__ import annotations
from typing import Literal
import torch
import torch.nn.functional as F
# Clamping the weight keeps gradients finite while still up-weighting lightly
# masked (low-t) samples. Historical value 1/eps=1000 blew up HYDRA training on
# a 12h v2 launch (2026-04-22): loss 26 → 42 → NaN in 13 steps under Muon
# lr=7e-3 because per-token CE × 1000 saturated the 100-unit FAIL guard. The
# MDLM paper reports stable training at Adam lr=1e-4; HYDRA uses Muon at 7e-3
# (70× larger), so the weight clamp needs to compensate.
#
# Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0; must be a finite positive
# number). Set =1.0 to disable weighting entirely (flat masked-LM CE, no RB
# reweighting — simpler and more stable, sacrifices the theoretical ELBO
# property).
import os as _os
import math as _math


def _parse_max_weight(raw: str) -> float:
    """Parse the HYDRA_MDLM_MAX_WEIGHT setting into a validated weight cap.

    The cap must be a finite, strictly positive float:
      * 0 would divide by zero when computing _MIN_MASK_PROB;
      * a negative value yields a negative mask-probability floor, which
        silently disables the clamp (weights become 1/t, unbounded as t→0);
      * inf yields a floor of 0, so a row sampled at t=0 gets weight inf and
        inf * 0 = NaN in its (otherwise all-zero) loss weights;
      * nan yields a NaN floor.
    Values in (0, 1) are accepted: they push the floor above 1, so every
    masked token receives the constant weight `cap` (a uniform loss scale).

    Raises:
        ValueError: if *raw* does not parse as a float, or is not finite and
            strictly positive.
    """
    try:
        value = float(raw)
    except ValueError as err:
        raise ValueError(
            f"HYDRA_MDLM_MAX_WEIGHT must be a number, got {raw!r}"
        ) from err
    if not (_math.isfinite(value) and value > 0.0):
        raise ValueError(
            f"HYDRA_MDLM_MAX_WEIGHT must be a finite positive number, got {raw!r}"
        )
    return value


_MAX_WEIGHT: float = _parse_max_weight(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0"))
# Floor on the mask probability t = 1 - alpha_t, chosen so that
# clamp(mask_prob, min=_MIN_MASK_PROB) guarantees 1 / mask_prob <= _MAX_WEIGHT.
_MIN_MASK_PROB: float = 1.0 / _MAX_WEIGHT
# Back-compat export for older tests/scripts that imported _MIN_ALPHA. The
# minimum now applies to mask probability t = 1 - alpha_t, not alpha_t itself.
_MIN_ALPHA: float = _MIN_MASK_PROB
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def mdlm_masked_forward_process(
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """MDLM forward (noising) process: mask tokens and compute RB weights.

    Args:
        targets: (B, T) int64 token ids — the clean sequence x_0.
        mask_token_id: The special token id used to represent a masked token.
        t: (B,) float in [0, 1]. If None, samples Uniform[0, 1) per batch
            element. t=0 means fully clean; t=1 means fully masked.
        alpha_schedule: Noise schedule.
            "loglinear" (MDLM default): alpha_t = 1 - t
            "linear": identical formula — both are provided for completeness
            since the paper calls the 1-t schedule "log-linear" in the context
            of the ELBO derivation.

    Returns:
        x_t           : (B, T), same dtype as targets — noised sequence;
                        masked positions hold mask_token_id, unmasked
                        positions equal targets.
        mask_positions: (B, T) bool — True where the token was masked.
        loss_weights  : (B, T) float32 — RB weighting factor. On masked
                        positions: 1/(1-alpha_t) = 1/mask_prob, with
                        mask_prob clamped below at _MIN_MASK_PROB so the
                        weight never exceeds _MAX_WEIGHT. On unmasked
                        positions: 0.0.
                        Summing CE * loss_weights over a row and dividing by
                        the row's token count gives the MDLM per-token
                        RB-ELBO estimate; dividing by the row's masked count
                        instead adds an extra factor of roughly 1/t.

    Raises:
        ValueError: if t is not shape (B,), if t contains values outside
            [0, 1] (including NaN), or if alpha_schedule is unknown.
    """
    B, T = targets.shape
    device = targets.device
    dtype = torch.float32

    # --- sample or validate t ---
    if t is None:
        # torch.rand samples [0, 1): t == 0 is possible (that row is then left
        # fully clean) and the mask_prob clamp below keeps its weight finite.
        t = torch.rand(B, device=device, dtype=dtype)
    else:
        t = t.to(device=device, dtype=dtype)
        if t.shape != (B,):
            raise ValueError(f"t must be shape (B,)={(B,)}, got {t.shape}")
        # Phrased as "not all in range" so NaN (which fails every comparison)
        # is rejected as well; `(t < 0) | (t > 1)` would let NaN through.
        if not bool(((t >= 0) & (t <= 1)).all()):
            raise ValueError("t must be in [0, 1]")

    # --- noise schedule: alpha_t = probability that a token is NOT masked ---
    # Both "linear" and "loglinear" in MDLM use alpha_t = 1 - t; the paper
    # calls it "log-linear" in the context of the ELBO derivation. We expose
    # both names for clarity.
    if alpha_schedule not in ("linear", "loglinear"):
        raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.")
    alpha_t = 1.0 - t  # (B,) float, in [0, 1]

    # --- per-token Bernoulli(1 - alpha_t) mask, independent per token ---
    # `>=` rather than `>` makes both endpoints exact for torch.rand's [0, 1)
    # range: alpha_t = 0 (t = 1) masks every token, alpha_t = 1 (t = 0) none.
    rand = torch.rand(B, T, device=device, dtype=dtype)
    mask_positions = rand >= alpha_t[:, None]  # (B, T) bool, True = masked

    # --- build x_t: replace masked positions with the MASK token ---
    x_t = targets.masked_fill(mask_positions, mask_token_id)

    # --- RB loss weights: inverse mask probability on masked positions ---
    # MDLM's continuous-time factor is -alpha'_t / (1 - alpha_t); with
    # alpha_t = 1 - t this is 1 / t. Clamp mask_prob so weights stay finite
    # near t→0, where only rare masked tokens appear.
    mask_prob = (1.0 - alpha_t).clamp(min=_MIN_MASK_PROB)  # (B,)
    # (B, 1) * (B, T) broadcast; multiplying by the mask zeroes unmasked slots.
    loss_weights = (1.0 / mask_prob)[:, None] * mask_positions.to(dtype)

    return x_t, mask_positions, loss_weights
def mdlm_rb_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_positions: torch.Tensor,
    loss_weights: torch.Tensor,
    ignore_index: int = -100,
    *,
    normalize: Literal["masked", "tokens"] = "masked",
) -> torch.Tensor:
    """Rao-Blackwellized negative ELBO.

    Cross-entropy on masked positions only, weighted per token by
    loss_weights, reduced per sample and then averaged over the batch:

        L = mean_B [ sum_T(w_i * CE(logits_i, target_i) * valid_i) / max(N, 1) ]

    where valid_i = mask_i & (target_i != ignore_index). The per-sample
    denominator N is selected by `normalize`:

      * "masked" (default, historical behaviour): N is the number of valid
        masked positions in the row. Masked positions whose label is
        ignore_index contribute nothing and are not counted.
      * "tokens": N is the number of non-ignored positions in the row, masked
        or not. With loss_weights = 1/t this matches the per-token
        normalization of the MDLM reference implementation (weighted NLL
        summed over tokens divided by token count, here computed per row).
        Because the masked count is about t * N_tokens, "masked" carries an
        extra factor of roughly 1/t relative to this.

    Args:
        logits        : (B, T, V) raw logits. May be bf16; internally cast to
                        float32 for CE computation.
        targets       : (B, T) int64 true token ids (x_0).
        mask_positions: (B, T) bool — True = masked position.
        loss_weights  : (B, T) float32 — inverse mask probability on masked
                        positions, 0 elsewhere.
        ignore_index  : Passed to F.cross_entropy; positions with this label
                        are excluded from the loss and from the denominator.
        normalize     : "masked" or "tokens" (keyword-only), see above.

    Returns:
        Scalar loss (float32 for float32 weights). Returns 0.0 if no valid
        positions are masked or the batch is empty.

    Raises:
        ValueError: if `normalize` is unknown or targets is not shape (B, T).
    """
    if normalize not in ("masked", "tokens"):
        raise ValueError(f"Unknown normalize: {normalize!r}. Use 'masked' or 'tokens'.")
    B, T, V = logits.shape
    if targets.shape != (B, T):
        raise ValueError(f"targets must be shape (B, T)={(B, T)}, got {tuple(targets.shape)}")

    # Explicit float32: F.cross_entropy accepts fp16/bf16 logits, but being
    # explicit avoids silent precision surprises.
    logits_f = logits.float()  # (B, T, V)

    # Positions that actually contribute: masked AND carrying a real label.
    # Honours pre-existing ignore_index labels (e.g. doc-separator masking).
    non_ignored = targets != ignore_index  # (B, T) bool
    valid = mask_positions.bool() & non_ignored  # (B, T) bool

    # Relabel everything outside `valid` as ignore_index so CE is 0 there.
    targets_masked = targets.masked_fill(~valid, ignore_index)
    per_tok_ce = F.cross_entropy(
        logits_f.reshape(B * T, V),
        targets_masked.reshape(B * T),
        ignore_index=ignore_index,
        reduction="none",
    ).reshape(B, T)  # (B, T) float32

    # Apply RB weight. loss_weights already has 0 on unmasked positions.
    weighted = per_tok_ce * loss_weights  # (B, T)

    counted = valid if normalize == "masked" else non_ignored
    denom = counted.sum(dim=1).clamp(min=1).to(per_tok_ce.dtype)  # (B,)
    per_sample_loss = weighted.sum(dim=1) / denom  # (B,)
    # Equal to .mean() for B > 0, but returns 0.0 instead of NaN when B == 0.
    return per_sample_loss.sum() / max(B, 1)
def mdlm_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
    ignore_index: int = -100,
) -> torch.Tensor:
    """Convenience wrapper: forward process + RB-ELBO in one call.

    WARNING: this samples a *fresh* random mask internally and scores
    `logits` against it. The noised sequence x_t produced by the forward
    process is discarded, so `logits` cannot have been computed from it.
    Passing the same `t` does not reproduce an earlier mask either, because
    masks are independent Bernoulli draws. Only use this wrapper when the
    logits do not depend on the masking (e.g. smoke tests or loss-scale
    checks). For training, call mdlm_masked_forward_process, feed its x_t to
    the model, and pass the resulting logits together with the SAME
    mask_positions / loss_weights to mdlm_rb_loss.

    Args:
        logits        : (B, T, V) raw logits.
        targets       : (B, T) int64 clean token ids.
        mask_token_id : The MASK token id used to corrupt the input.
        t             : Optional (B,) timestep in [0, 1]. Sampled if None.
        alpha_schedule: "loglinear" (default) or "linear".
        ignore_index  : Token id to ignore in the loss (e.g. padding).

    Returns:
        Scalar float32 MDLM RB-ELBO loss.

    Note on sampled-softmax / partial logits:
        If your model only computes logits for a subset of vocab positions
        (e.g. HYDRA's sampled-softmax head), call mdlm_masked_forward_process
        and mdlm_rb_loss separately. mdlm_rb_loss expects full-vocab logits.
    """
    # x_t is intentionally discarded: the caller's logits were already
    # computed, so they cannot reflect this freshly sampled corruption (see
    # also the orchestrator wiring note in training.py).
    _, mask_positions, loss_weights = mdlm_masked_forward_process(
        targets=targets,
        mask_token_id=mask_token_id,
        t=t,
        alpha_schedule=alpha_schedule,
    )
    return mdlm_rb_loss(
        logits=logits,
        targets=targets,
        mask_positions=mask_positions,
        loss_weights=loss_weights,
        ignore_index=ignore_index,
    )