"""MDLM Rao-Blackwellized Masked Diffusion Loss. Implements the masked-diffusion ELBO from: Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM), NeurIPS 2024, arXiv:2406.07524. Equations referenced: - Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t) - Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1) - RB-ELBO: eq. 7-8 L_RB = E_t E_q [ (alpha'_t / (1 - alpha_t)) * CE(x_theta(x_t), x_0) ] where the expectation is over masked positions. For alpha_t = 1 - t, the magnitude is proportional to 1 / t, i.e. inverse mask probability, not inverse keep probability. Key insight: the Rao-Blackwellized estimate replaces an average over all masks (exponential) by a closed-form weighted CE that applies inverse mask-probability weight only on the positions that were masked, and 0 on unmasked positions. This gives an unbiased estimator with lower variance than a naive Monte Carlo over mask patterns. Reference implementation cross-checked against: https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss) """ from __future__ import annotations from typing import Literal import torch import torch.nn.functional as F # Clamping weight keeps gradients finite while still up-weighting high-noise # positions. Historical value 1/eps=1000 blew up HYDRA training on a 12h v2 # launch (2026-04-22): loss 26 → 42 → NaN in 13 steps under Muon lr=7e-3 # because per-token CE × 1000 saturated the 100-unit FAIL guard. The MDLM # paper reports stable training at Adam lr=1e-4; HYDRA uses Muon at 7e-3 # (70× larger), so the weight clamp needs to compensate. # # Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0). Set =1.0 to disable # weighting entirely (flat masked-LM CE, no RB reweighting — simpler and # more stable, sacrifices the theoretical ELBO property). import os as _os _MAX_WEIGHT: float = float(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0")) _MIN_MASK_PROB: float = 1.0 / _MAX_WEIGHT # so clamp(mask_prob, min=...) gives 1/mask_prob <= _MAX_WEIGHT # Back-compat export for older tests/scripts that imported _MIN_ALPHA. The # minimum now applies to mask probability t = 1 - alpha_t, not alpha_t itself. _MIN_ALPHA: float = _MIN_MASK_PROB # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def mdlm_masked_forward_process( targets: torch.Tensor, mask_token_id: int, t: torch.Tensor | None = None, alpha_schedule: Literal["linear", "loglinear"] = "loglinear", ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """MDLM forward (noising) process: mask tokens and compute RB weights. Args: targets: (B, T) int64 token ids — the clean sequence x_0. mask_token_id: The special token id used to represent a masked token. t: (B,) float in (0, 1). If None, samples Uniform(0, 1) per batch element. t=0 means fully clean; t=1 means fully masked. alpha_schedule: Noise schedule. "loglinear" (MDLM default): alpha_t = 1 - t "linear": identical formula — both are provided for completeness since the paper calls the 1-t schedule "log-linear" in the context of the ELBO derivation. Returns: x_t : (B, T) int64 — noised sequence; masked positions hold mask_token_id, unmasked positions equal targets. mask_positions: (B, T) bool — True where the token was masked. loss_weights : (B, T) float32 — RB weighting factor. On masked positions: 1/(1-alpha_t), i.e. 1/mask_prob (clamped to _MAX_WEIGHT). On unmasked positions: 0.0. Summing (CE * loss_weights * mask_positions).sum() / mask.sum() gives the per-sample RB-ELBO estimator. """ B, T = targets.shape device = targets.device dtype = torch.float32 # --- sample or validate t --- if t is None: # Uniform(0, 1) per batch element; avoid exactly 0 and 1. t = torch.rand(B, device=device, dtype=dtype) else: t = t.to(device=device, dtype=dtype) if t.shape != (B,): raise ValueError(f"t must be shape (B,)={(B,)}, got {t.shape}") if (t < 0).any() or (t > 1).any(): raise ValueError("t must be in [0, 1]") # --- noise schedule: alpha_t = probability that a token is NOT masked --- # Both "linear" and "loglinear" in MDLM use alpha_t = 1 - t; the paper # refers to "log-linear" because the schedule is linear in the *log* domain # of the forward process probability. We expose both names for clarity. if alpha_schedule in ("linear", "loglinear"): alpha_t = 1.0 - t # (B,) float, in [0, 1] else: raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.") # --- per-token Bernoulli mask --- # alpha_t[:, None] broadcasts to (B, T). alpha_t_expanded = alpha_t[:, None] # (B, 1) # Bernoulli(1 - alpha_t) = 1 means "mask this token". # We sample independently per token, per batch element. rand = torch.rand(B, T, device=device, dtype=dtype) mask_positions = rand > alpha_t_expanded # (B, T) bool # True → masked position # False → unmasked (kept as original) # --- build x_t --- x_t = targets.clone() x_t = torch.where(mask_positions, torch.full_like(x_t, mask_token_id), x_t) # --- RB loss weights: inverse mask probability on masked positions, 0 elsewhere --- # MDLM's continuous-time factor is alpha'_t / (1 - alpha_t). With # alpha_t = 1 - t, magnitude is 1 / t. Clamp mask_prob so weights stay # finite near t→0, where only rare masked tokens appear. mask_prob = (1.0 - alpha_t).clamp(min=_MIN_MASK_PROB) # (B,) weight_per_sample = 1.0 / mask_prob # (B,) # Broadcast to (B, T) and zero out unmasked positions. loss_weights = weight_per_sample[:, None].expand(B, T).to(dtype=dtype) # (B, T) loss_weights = loss_weights * mask_positions.float() return x_t, mask_positions, loss_weights def mdlm_rb_loss( logits: torch.Tensor, targets: torch.Tensor, mask_positions: torch.Tensor, loss_weights: torch.Tensor, ignore_index: int = -100, ) -> torch.Tensor: """Rao-Blackwellized negative ELBO. Applies the MDLM loss: cross-entropy on masked positions only, weighted per-token by loss_weights, averaged over the batch. The formula (eq. 7-8 of arXiv:2406.07524): L_RB = mean_B [ sum_T (weight_t * CE(logits_i, target_i) * mask_i) / max(sum_T(mask_i), 1) ] Args: logits : (B, T, V) raw logits. May be bf16; internally cast to float32 for CE computation. targets : (B, T) int64 true token ids (x_0). mask_positions: (B, T) bool — True = masked position. loss_weights : (B, T) float32 — inverse mask probability on masked positions, 0 elsewhere. ignore_index : Passed to F.cross_entropy; positions with this label are excluded from the loss. Returns: Scalar float32 loss. Returns 0.0 tensor if no positions are masked. """ B, T, V = logits.shape # Ensure float32 for numerical stability; F.cross_entropy accepts fp16/bf16 # logits but accumulates in float internally anyway. Being explicit avoids # silent precision surprises. logits_f = logits.float() # (B, T, V) # Build targets with ignore_index on UNmasked positions so CE only fires # where mask_positions is True. We also honour any pre-existing -100 values # (e.g. doc-separator masking upstream). targets_masked = torch.where( mask_positions & (targets != ignore_index), targets, torch.full_like(targets, ignore_index), ) # Per-token CE; shape (B, T). Positions with ignore_index → 0 from CE. per_tok_ce = F.cross_entropy( logits_f.reshape(B * T, V), targets_masked.reshape(B * T), ignore_index=ignore_index, reduction="none", ).reshape(B, T) # (B, T) float32 # Apply RB weight. loss_weights already has 0 on unmasked positions. weighted = per_tok_ce * loss_weights # (B, T) # Per-sample mean over masked positions, then average over batch. mask_f = mask_positions.float() # (B, T) per_sample_mask_count = mask_f.sum(dim=1).clamp(min=1) # (B,) per_sample_loss = weighted.sum(dim=1) / per_sample_mask_count # (B,) return per_sample_loss.mean() # scalar float32 def mdlm_loss( logits: torch.Tensor, targets: torch.Tensor, mask_token_id: int, t: torch.Tensor | None = None, alpha_schedule: Literal["linear", "loglinear"] = "loglinear", ignore_index: int = -100, ) -> torch.Tensor: """Convenience wrapper: forward process + RB-ELBO in one call. Suitable for the common case where the caller has full-vocab logits and wants a drop-in replacement for a standard masked-LM CE loss. Args: logits : (B, T, V) raw logits. targets : (B, T) int64 clean token ids. mask_token_id : The MASK token id used to corrupt the input. t : Optional (B,) timestep in (0, 1). Sampled if None. alpha_schedule: "loglinear" (default) or "linear". ignore_index : Token id to ignore in the loss (e.g. padding). Returns: Scalar float32 MDLM RB-ELBO loss. Note on sampled-softmax / partial logits: If your model only computes logits for a subset of vocab positions (e.g. HYDRA's sampled-softmax head), call mdlm_masked_forward_process and mdlm_rb_loss separately. mdlm_rb_loss expects full-vocab logits. """ x_t, mask_positions, loss_weights = mdlm_masked_forward_process( targets=targets, mask_token_id=mask_token_id, t=t, alpha_schedule=alpha_schedule, ) # x_t is produced for the model's input (not used by this convenience # wrapper since logits are already provided by the caller). In a real # training loop the caller feeds x_t into the model to get logits, THEN # calls this function. See the orchestrator wiring note in training.py. return mdlm_rb_loss( logits=logits, targets=targets, mask_positions=mask_positions, loss_weights=loss_weights, ignore_index=ignore_index, )