Spaces:
Runtime error
Runtime error
| """MDLM Rao-Blackwellized Masked Diffusion Loss. | |
| Implements the masked-diffusion ELBO from: | |
| Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM), | |
| NeurIPS 2024, arXiv:2406.07524. | |
| Equations referenced: | |
| - Forward process: eq. 2 (per-token Bernoulli masking at rate 1 - alpha_t) | |
| - Log-linear schedule: alpha_t = 1 - t, t ~ Uniform(0, 1) | |
| - RB-ELBO: eq. 7-8 L_RB = E_t E_q [ (1/alpha_t) * CE(x_theta(x_t), x_0) ] | |
| where the expectation is taken over masked positions. | |
| Key insight: the Rao-Blackwellized estimate replaces an average over all masks | |
| (exponential) by a closed-form weighted CE that applies weight 1/alpha_t only | |
| on the positions that were masked, and 0 on unmasked positions. This gives an | |
| unbiased estimator with lower variance than a naive Monte Carlo over mask | |
| patterns. | |
| Reference implementation cross-checked against: | |
| https://github.com/kuleshov-group/mdlm (diffusion.py::DiffusionModel._loss) | |
| """ | |
| from __future__ import annotations | |
| from typing import Literal | |
| import torch | |
| import torch.nn.functional as F | |
| # Clamping weight keeps gradients finite while still up-weighting high-noise | |
| # positions. Historical value 1/eps=1000 blew up HYDRA training on a 12h v2 | |
| # launch (2026-04-22): loss 26 → 42 → NaN in 13 steps under Muon lr=7e-3 | |
| # because per-token CE × 1000 saturated the 100-unit FAIL guard. The MDLM | |
| # paper reports stable training at Adam lr=1e-4; HYDRA uses Muon at 7e-3 | |
| # (70× larger), so the weight clamp needs to compensate. | |
| # | |
| # Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0). Set =1.0 to disable | |
| # weighting entirely (flat masked-LM CE, no RB reweighting — simpler and | |
| # more stable, sacrifices the theoretical ELBO property). | |
import os as _os

# Upper bound on the per-token RB weight. Read once, at import time, from the
# HYDRA_MDLM_MAX_WEIGHT environment variable (default 5.0).
_MAX_WEIGHT: float = float(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0"))

# Fail fast on unusable settings. Zero would otherwise surface as a bare
# ZeroDivisionError on the next line; negative, NaN or infinite values would
# silently turn the alpha clamp into a no-op and let 1/alpha_t reach inf.
# The chained comparison is False for NaN, so NaN is rejected as well.
if not 0.0 < _MAX_WEIGHT < float("inf"):
    raise ValueError(
        f"HYDRA_MDLM_MAX_WEIGHT must be a positive finite number, got {_MAX_WEIGHT!r}"
    )

# Floor applied to alpha_t so that 1 / clamp(alpha_t, min=_MIN_ALPHA) <= _MAX_WEIGHT.
_MIN_ALPHA: float = 1.0 / _MAX_WEIGHT
| # --------------------------------------------------------------------------- | |
| # Public API | |
| # --------------------------------------------------------------------------- | |
def mdlm_masked_forward_process(
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """MDLM forward (noising) process: mask tokens and compute RB weights.

    Args:
        targets: (B, T) integer token ids — the clean sequence x_0.
        mask_token_id: The special token id used to represent a masked token.
        t: (B,) timesteps in [0, 1]. If None, one value per batch element is
            drawn with torch.rand (so from [0, 1)). t=0 leaves the sequence
            fully clean; t=1 masks every token.
        alpha_schedule: Noise schedule name. "loglinear" and "linear" are both
            accepted and both map to alpha_t = 1 - t.

    Returns:
        x_t: (B, T), same dtype as ``targets`` — noised sequence; masked
            positions hold ``mask_token_id``, the rest equal ``targets``.
        mask_positions: (B, T) bool — True where the token was masked.
        loss_weights: (B, T) float32 — 1 / max(alpha_t, _MIN_ALPHA) on masked
            positions (i.e. at most _MAX_WEIGHT), 0.0 on unmasked positions.

    Raises:
        ValueError: if ``targets`` is not 2-D, if ``t`` does not have shape
            (B,), if any entry of ``t`` is NaN or outside [0, 1], or if
            ``alpha_schedule`` is not a known name.

    NOTE(review): the public MDLM reference code appears to weight the
    log-linear schedule by 1/t = 1/(1 - alpha_t) rather than 1/alpha_t. The
    1/alpha_t weighting is kept here unchanged because callers and the clamp
    tuning depend on it — confirm against the paper before relying on the
    ELBO interpretation.
    """
    if targets.dim() != 2:
        raise ValueError(
            f"targets must be 2-D (B, T), got shape {tuple(targets.shape)}"
        )
    B, T = targets.shape
    device = targets.device
    dtype = torch.float32

    # --- sample or validate t ---
    if t is None:
        # torch.rand draws from [0, 1): t == 0 is possible (nothing masked,
        # zero loss for that sample), t == 1 is not.
        t = torch.rand(B, device=device, dtype=dtype)
    else:
        t = t.to(device=device, dtype=dtype)
        if t.shape != (B,):
            raise ValueError(f"t must be shape (B,)={(B,)}, got {t.shape}")
        # Written as a positive range test so NaN (for which every comparison
        # is False) is rejected instead of slipping through and turning the
        # loss weights into NaN.
        if not bool(((t >= 0) & (t <= 1)).all()):
            raise ValueError("t must be in [0, 1]")

    # --- noise schedule: alpha_t = probability that a token is NOT masked ---
    if alpha_schedule in ("linear", "loglinear"):
        alpha_t = 1.0 - t  # (B,), in [0, 1]
    else:
        raise ValueError(
            f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'."
        )

    # --- per-token Bernoulli(1 - alpha_t) mask ---
    # With u ~ U[0, 1), "u >= alpha_t" has probability 1 - alpha_t and is exact
    # at both ends: alpha_t == 1 masks nothing, alpha_t == 0 masks everything
    # (a strict ">" could leave a token unmasked at t == 1 when u == 0.0).
    rand = torch.rand(B, T, device=device, dtype=dtype)
    mask_positions = rand >= alpha_t[:, None]  # (B, T) bool

    # --- build x_t (out-of-place; ``targets`` is left untouched) ---
    x_t = targets.masked_fill(mask_positions, mask_token_id)

    # --- RB loss weights: 1/alpha_t on masked positions, 0 elsewhere ---
    # Clamp alpha_t from below so the weight stays finite as t -> 1.
    weight_per_sample = 1.0 / alpha_t.clamp(min=_MIN_ALPHA)  # (B,)
    loss_weights = weight_per_sample[:, None] * mask_positions.to(dtype)  # (B, T)

    return x_t, mask_positions, loss_weights
def mdlm_rb_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_positions: torch.Tensor,
    loss_weights: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Rao-Blackwellized masked-diffusion loss.

    Cross-entropy is evaluated only where a position is masked AND its target
    is not ``ignore_index``. Each such token's CE is multiplied by its entry
    in ``loss_weights``, averaged per sample over those scored positions, and
    then averaged over the batch:

        L = mean_B [ sum_T (w_i * CE(logits_i, target_i) * scored_i)
                     / max(sum_T(scored_i), 1) ]

    Args:
        logits: (B, T, V) raw logits. May be fp16/bf16; cast to float32 before
            the cross-entropy.
        targets: (B, T) integer true token ids (x_0). Entries equal to
            ``ignore_index`` are never scored.
        mask_positions: (B, T) mask, True (non-zero) = masked position.
        loss_weights: (B, T) per-token weight. Values at unscored positions
            have no effect because the CE there is 0.
        ignore_index: Label value excluded from the loss (e.g. padding).

    Returns:
        Scalar loss. A sample with no scored position contributes 0.
    """
    B, T, V = logits.shape

    # Positions that actually contribute: masked by the forward process and
    # not already excluded upstream via ignore_index.
    scored = mask_positions.to(torch.bool) & (targets != ignore_index)  # (B, T)

    # Everything else is relabelled ignore_index so F.cross_entropy returns 0
    # there with reduction="none".
    ce_targets = targets.masked_fill(~scored, ignore_index)

    per_tok_ce = F.cross_entropy(
        logits.float().reshape(B * T, V),  # float32 for numerical stability
        ce_targets.reshape(B * T),
        ignore_index=ignore_index,
        reduction="none",
    ).reshape(B, T)

    weighted = per_tok_ce * loss_weights  # (B, T)

    # Normalise by the number of scored positions, not the raw mask count:
    # masked positions whose label is ignore_index add nothing to the
    # numerator, so counting them would dilute the loss for padded samples.
    # The clamp keeps a sample without scored positions at 0 instead of 0/0.
    scored_count = scored.sum(dim=1).clamp(min=1).to(weighted.dtype)  # (B,)
    per_sample_loss = weighted.sum(dim=1) / scored_count  # (B,)
    return per_sample_loss.mean()
def mdlm_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
    ignore_index: int = -100,
) -> torch.Tensor:
    """One-call convenience: run the forward process, then the RB loss.

    A fresh mask and its weights are drawn from ``targets`` via
    ``mdlm_masked_forward_process`` and passed, together with the supplied
    ``logits``, to ``mdlm_rb_loss``.

    Args:
        logits: (B, T, V) raw logits.
        targets: (B, T) int64 clean token ids.
        mask_token_id: The MASK token id used to corrupt the input.
        t: Optional (B,) timestep; sampled when None.
        alpha_schedule: "loglinear" (default) or "linear".
        ignore_index: Token id to ignore in the loss (e.g. padding).

    Returns:
        Scalar MDLM RB loss.

    Caveat:
        The mask is sampled inside this call, after ``logits`` already exist,
        and the corrupted sequence x_t is thrown away. The mask scored here is
        therefore not tied to whatever input produced ``logits``. When the
        model must see x_t (the usual training setup), or when it only emits
        logits for part of the vocabulary, call
        ``mdlm_masked_forward_process`` and ``mdlm_rb_loss`` separately.
    """
    # Only the mask and the weights are needed; x_t is intentionally dropped.
    _, masked, weights = mdlm_masked_forward_process(
        targets,
        mask_token_id,
        t=t,
        alpha_schedule=alpha_schedule,
    )
    return mdlm_rb_loss(logits, targets, masked, weights, ignore_index=ignore_index)