| """MDLM Rao-Blackwellized Masked Diffusion Loss. |
| |
| Implements the masked-diffusion ELBO from: |
| Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM), |
| NeurIPS 2024, arXiv:2406.07524. |
| |
| Equations referenced: |
| - Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t) |
| - Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1) |
  - RB-ELBO: eq. 7-8   L_RB = E_t E_q [ (-alpha'_t / (1 - alpha_t)) *
                                  CE(x_theta(x_t), x_0) ], where the CE is
                                  summed over masked positions only. For
                                  alpha_t = 1 - t the weight is 1 / t, i.e.
                                  inverse mask probability, not inverse keep
                                  probability.
| |
| Key insight: the Rao-Blackwellized estimate replaces an average over all masks |
| (exponential) by a closed-form weighted CE that applies inverse mask-probability |
| weight only on the positions that were masked, and 0 on unmasked positions. This |
| gives an unbiased estimator with lower variance than a naive Monte Carlo over |
| mask patterns. |
| |
| Reference implementation cross-checked against: |
| https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss) |
| """ |
|
|
from __future__ import annotations

from collections.abc import Callable
from typing import Literal

import torch
import torch.nn.functional as F
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
import math as _math
import os as _os


def _read_max_weight() -> float:
    """Read and validate the RB weight cap from ``HYDRA_MDLM_MAX_WEIGHT``.

    Returns:
        The parsed cap. Defaults to ``5.0`` when the variable is unset.

    Raises:
        ValueError: if the value is not a number, is not finite, or is <= 0.
            Zero would divide by zero when computing ``_MIN_MASK_PROB``.
            A negative value would make the mask-probability clamp a no-op.
            ``inf`` would permit an infinite weight at t == 0, and
            ``inf * 0`` on unmasked positions then produces NaN weights.
    """
    raw = _os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0")
    try:
        value = float(raw)
    except ValueError as err:
        raise ValueError(
            f"HYDRA_MDLM_MAX_WEIGHT must be a number, got {raw!r}"
        ) from err
    if not _math.isfinite(value) or value <= 0.0:
        raise ValueError(
            f"HYDRA_MDLM_MAX_WEIGHT must be a finite positive number, got {raw!r}"
        )
    return value


# Upper bound on the per-token RB weight 1 / mask_prob. It is enforced by
# clamping the mask probability from below at _MIN_MASK_PROB.
_MAX_WEIGHT: float = _read_max_weight()
_MIN_MASK_PROB: float = 1.0 / _MAX_WEIGHT

# NOTE(review): same value as _MIN_MASK_PROB despite the "alpha" name, and
# not referenced in the visible code. It is presumably kept for external
# importers; verify before removing.
_MIN_ALPHA: float = _MIN_MASK_PROB
|
|
|
|
| |
| |
| |
|
|
def mdlm_masked_forward_process(
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """MDLM forward (noising) process: mask tokens and compute RB weights.

    Args:
        targets: (B, T) integer token ids, i.e. the clean sequence x_0.
        mask_token_id: Token id written into masked positions.
        t: Optional (B,) noise level in [0, 1]. If None, t ~ Uniform[0, 1)
            is sampled per batch element. t = 0 leaves the sequence clean;
            t = 1 masks every position.
        alpha_schedule: "loglinear" (MDLM default) or "linear". Both use
            alpha_t = 1 - t. The two names are accepted because the paper
            calls this schedule "log-linear" in the ELBO derivation.

    Returns:
        x_t: (B, T) noised sequence with the same dtype as ``targets``.
            Masked positions hold ``mask_token_id``; all others equal
            ``targets``.
        mask_positions: (B, T) bool, True where the token was masked. Each
            position is masked independently with probability 1 - alpha_t.
        loss_weights: (B, T) float32.
            - Masked positions: 1 / (1 - alpha_t), i.e. 1 / t. The mask
              probability is clamped below at ``_MIN_MASK_PROB``, so the
              weight never exceeds ``_MAX_WEIGHT``.
            - Unmasked positions: 0.0.
            Intended as the ``loss_weights`` argument of ``mdlm_rb_loss``.

    Raises:
        ValueError: if ``t`` has the wrong shape, or contains values outside
            [0, 1] (NaN included), or ``alpha_schedule`` is unknown.
    """
    batch, seq_len = targets.shape
    device = targets.device
    dtype = torch.float32

    # RNG call order is significant for seeded runs: t is drawn before the
    # per-token noise.
    if t is None:
        t = torch.rand(batch, device=device, dtype=dtype)
    else:
        t = t.to(device=device, dtype=dtype)
        if t.shape != (batch,):
            raise ValueError(f"t must be shape (B,)={(batch,)}, got {t.shape}")
        # "All in range" rather than "any out of range" so that NaN, which
        # fails every comparison, is rejected instead of producing NaN
        # loss weights downstream.
        if not bool(((t >= 0) & (t <= 1)).all()):
            raise ValueError("t must be in [0, 1]")

    if alpha_schedule not in ("linear", "loglinear"):
        raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.")
    alpha_t = 1.0 - t

    # noise ~ U[0, 1), so P(noise > alpha_t) = 1 - alpha_t for each position.
    noise = torch.rand(batch, seq_len, device=device, dtype=dtype)
    mask_positions = noise > alpha_t[:, None]

    x_t = targets.masked_fill(mask_positions, mask_token_id)

    # Inverse mask probability, capped at _MAX_WEIGHT so small t cannot blow
    # up the loss. Zeroed on unmasked positions.
    mask_prob = (1.0 - alpha_t).clamp(min=_MIN_MASK_PROB)
    loss_weights = (1.0 / mask_prob)[:, None] * mask_positions.to(dtype)

    return x_t, mask_positions, loss_weights
|
|
|
|
def mdlm_rb_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_positions: torch.Tensor,
    loss_weights: torch.Tensor,
    ignore_index: int = -100,
    *,
    normalize: Literal["masked", "tokens"] = "masked",
) -> torch.Tensor:
    """Rao-Blackwellized negative ELBO.

    Cross-entropy is taken on masked, non-ignored positions only. It is
    weighted per token by ``loss_weights``, normalized per sample, and
    averaged over the batch::

        s_bi  = mask_bi and (target_bi != ignore_index)
        L_RB  = mean_b [ sum_i (w_bi * CE_bi * s_bi) / max(N_b, 1) ]

    N_b depends on ``normalize``:

    - ``"masked"`` (default): N_b = sum_i s_bi, the number of scored masked
      positions. Masked padding positions are not counted, because they
      contribute no CE term.
    - ``"tokens"``: N_b = number of non-ignored positions in the sample.
      With un-clamped 1/t weights this is the per-token normalization of
      the MDLM objective.

      NOTE(review): with ``"masked"``, N_b grows roughly like t * T, so
      each sample is scaled by roughly 1/t relative to ``"tokens"``.

    Args:
        logits: (B, T, V) raw logits. May be bf16; cast to float32
            internally for the CE computation.
        targets: (B, T) int64 true token ids (x_0).
        mask_positions: (B, T) bool, True = masked position.
        loss_weights: (B, T) float32 per-token weights, e.g. inverse mask
            probability on masked positions and 0 elsewhere.
        ignore_index: Target value excluded from the loss and from N_b.
        normalize: Per-sample denominator, "masked" or "tokens" (see above).

    Returns:
        Scalar float32 loss. It is 0.0 when no position contributes, since
        the denominator is clamped to at least 1.

    Raises:
        ValueError: if ``normalize`` is not "masked" or "tokens".
    """
    if normalize not in ("masked", "tokens"):
        raise ValueError(
            f"Unknown normalize: {normalize!r}. Use 'masked' or 'tokens'."
        )
    batch, seq_len, vocab = logits.shape

    not_ignored = targets != ignore_index
    # Positions that actually produce a CE term.
    scored = mask_positions & not_ignored

    # Route every non-scored position to ignore_index so F.cross_entropy
    # returns 0 there (reduction="none").
    ce_targets = targets.masked_fill(~scored, ignore_index)
    per_tok_ce = F.cross_entropy(
        logits.float().reshape(batch * seq_len, vocab),
        ce_targets.reshape(batch * seq_len),
        ignore_index=ignore_index,
        reduction="none",
    ).reshape(batch, seq_len)

    numerator = (per_tok_ce * loss_weights).sum(dim=1)

    denom_positions = scored if normalize == "masked" else not_ignored
    # clamp(min=1) keeps samples with nothing to score at 0 instead of NaN.
    denominator = denom_positions.float().sum(dim=1).clamp(min=1)

    return (numerator / denominator).mean()
|
|
|
|
def mdlm_loss(
    logits: torch.Tensor | Callable[[torch.Tensor], torch.Tensor],
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
    ignore_index: int = -100,
) -> torch.Tensor:
    """Convenience wrapper: forward process + RB-ELBO in one call.

    Args:
        logits: Either a callable (for example the denoiser module or a
            lambda around it) that maps the noised ids x_t of shape (B, T)
            to full-vocab logits of shape (B, T, V), or a precomputed
            (B, T, V) logits tensor.
        targets: (B, T) int64 clean token ids.
        mask_token_id: The MASK token id used to corrupt the input.
        t: Optional (B,) timestep in [0, 1]. Sampled if None.
        alpha_schedule: "loglinear" (default) or "linear".
        ignore_index: Target value ignored by the loss (e.g. padding).

    Returns:
        Scalar float32 MDLM RB-ELBO loss.

    Note on tensor ``logits``:
        x_t is sampled inside this function and is not returned. A
        precomputed logits tensor therefore cannot have been produced from
        that x_t, so the loss does not score the model's denoising of the
        sampled input. Pass a callable instead, or call
        ``mdlm_masked_forward_process``, run the model on ``x_t``, and then
        call ``mdlm_rb_loss``.

    Note on sampled-softmax / partial logits:
        If your model only computes logits for a subset of vocab positions
        (e.g. HYDRA's sampled-softmax head), call
        ``mdlm_masked_forward_process`` and ``mdlm_rb_loss`` separately.
        ``mdlm_rb_loss`` expects full-vocab logits.
    """
    x_t, mask_positions, loss_weights = mdlm_masked_forward_process(
        targets=targets,
        mask_token_id=mask_token_id,
        t=t,
        alpha_schedule=alpha_schedule,
    )

    # Tensors are not callable; modules and functions are. This keeps the
    # original tensor-passing behavior unchanged.
    if callable(logits):
        logits = logits(x_t)

    return mdlm_rb_loss(
        logits=logits,
        targets=targets,
        mask_positions=mask_positions,
        loss_weights=loss_weights,
        ignore_index=ignore_index,
    )
|
|