"""MDLM Rao-Blackwellized Masked Diffusion Loss.

Implements the masked-diffusion ELBO from:
    Sahoo et al., "Simple and Effective Masked Diffusion Language Models" (MDLM),
    NeurIPS 2024, arXiv:2406.07524.

Equations referenced:
    - Forward process: eq. 2  (per-token Bernoulli masking at rate 1 - alpha_t)
    - Log-linear schedule:    alpha_t = 1 - t,  t ~ Uniform(0, 1)
    - RB-ELBO:     eq. 7-8   L_RB = E_t E_q [ (alpha'_t / (1 - alpha_t)) *
                              CE(x_theta(x_t), x_0) ] where the expectation is
                              over masked positions. For alpha_t = 1 - t, the
                              magnitude is proportional to 1 / t, i.e. inverse
                              mask probability, not inverse keep probability.

Key insight: the Rao-Blackwellized estimate replaces an average over all masks
(exponential) by a closed-form weighted CE that applies inverse mask-probability
weight only on the positions that were masked, and 0 on unmasked positions. This
gives an unbiased estimator with lower variance than a naive Monte Carlo over
mask patterns.

Reference implementation cross-checked against:
    https://github.com/kuleshov-group/mdlm  (diffusion.py::DiffusionModel._loss)
"""

from __future__ import annotations

from typing import Literal

import torch
import torch.nn.functional as F


# Clamping the weight keeps gradients finite while still up-weighting
# low-noise samples (small t, where few tokens are masked). The historical
# value 1/eps=1000 blew up HYDRA training on a 12h v2 launch (2026-04-22):
# loss 26 → 42 → NaN in 13 steps under Muon lr=7e-3, because per-token
# CE × 1000 saturated the 100-unit FAIL guard. The MDLM paper reports stable
# training at Adam lr=1e-4; HYDRA uses Muon at 7e-3 (70× larger), so the
# weight clamp needs to compensate.
#
# Tunable via HYDRA_MDLM_MAX_WEIGHT (default 5.0). Set it to 1.0 to disable
# weighting entirely (flat masked-LM CE, no RB reweighting): simpler and
# more stable, but it sacrifices the theoretical ELBO property.
import os as _os
_MAX_WEIGHT: float = float(_os.environ.get("HYDRA_MDLM_MAX_WEIGHT", "5.0"))
_MIN_MASK_PROB: float = 1.0 / _MAX_WEIGHT  # so clamp(mask_prob, min=...) gives 1/mask_prob <= _MAX_WEIGHT
# Back-compat export for older tests/scripts that imported _MIN_ALPHA. The
# minimum now applies to mask probability t = 1 - alpha_t, not alpha_t itself.
_MIN_ALPHA: float = _MIN_MASK_PROB
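

# Illustrative sketch only; nothing in this module calls it. The helper name
# _example_clamped_weight is hypothetical and is not part of the original API.
# It restates, for a single scalar t and using only built-ins, the clamp that
# mdlm_masked_forward_process applies to a whole batch: the RB weight is
# 1 / t, but t is floored at 1 / max_weight so the weight never exceeds
# max_weight. It assumes max_weight > 0.
def _example_clamped_weight(t: float, max_weight: float = 5.0) -> float:
    """Return 1 / max(t, 1 / max_weight), the clamped inverse mask probability."""
    min_mask_prob = 1.0 / max_weight      # same role as _MIN_MASK_PROB
    return 1.0 / max(t, min_mask_prob)    # bounded above by max_weight


# With the default max_weight=5.0, t=0.5 gives a weight of 2.0, while any
# t <= 0.2 is clamped to a weight of 5.0.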


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

def mdlm_masked_forward_process(
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """MDLM forward (noising) process: mask tokens and compute RB weights.

    Args:
        targets: (B, T) int64 token ids: the clean sequence x_0.
        mask_token_id: The special token id used to represent a masked token.
        t: (B,) float in [0, 1]. If None, samples Uniform[0, 1) per batch
            element. t=0 means fully clean; t=1 means fully masked.
        alpha_schedule: Noise schedule.
            "loglinear" (MDLM default): alpha_t = 1 - t
            "linear": identical formula. Both names are accepted since the
            paper calls the 1-t schedule "log-linear" in the context of the
            ELBO derivation.

    Returns:
        x_t           : (B, T) int64. Noised sequence; masked positions hold
                        mask_token_id, unmasked positions equal targets.
        mask_positions: (B, T) bool. True where the token was masked.
        loss_weights  : (B, T) float32. RB weighting factor. On masked
                        positions: 1/(1-alpha_t), i.e. 1/mask_prob (clamped to
                        at most _MAX_WEIGHT). On unmasked positions: 0.0.
                        For each sample,
                        (CE * loss_weights).sum() / mask_positions.sum()
                        gives the per-sample RB-ELBO estimator.
    """
    B, T = targets.shape
    device = targets.device
    dtype = torch.float32

    # --- sample or validate t ---
    if t is None:
        # Uniform[0, 1) per batch element. torch.rand can return exactly 0,
        # in which case nothing is masked and the sample contributes zero loss.
        t = torch.rand(B, device=device, dtype=dtype)
    else:
        t = t.to(device=device, dtype=dtype)
        if t.shape != (B,):
            raise ValueError(f"t must be shape (B,)={(B,)}, got {t.shape}")
        if (t < 0).any() or (t > 1).any():
            raise ValueError("t must be in [0, 1]")

    # --- noise schedule: alpha_t = probability that a token is NOT masked ---
    # Both "linear" and "loglinear" in MDLM use alpha_t = 1 - t. The paper's
    # "log-linear" name refers to the total noise sigma(t) = -log(1 - t), the
    # log of a linear function of t, which gives alpha_t = exp(-sigma(t)).
    # We accept both names for clarity.
    if alpha_schedule in ("linear", "loglinear"):
        alpha_t = 1.0 - t          # (B,) float, in [0, 1]
    else:
        raise ValueError(f"Unknown alpha_schedule: {alpha_schedule!r}. Use 'linear' or 'loglinear'.")

    # --- per-token Bernoulli mask ---
    # alpha_t[:, None] broadcasts to (B, T).
    alpha_t_expanded = alpha_t[:, None]                # (B, 1)
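    # Bernoulli(1 - alpha_t) = 1 means "mask this token".
    # We sample independently per token, per batch element. torch.rand draws
    # from [0, 1), so `>=` masks with probability 1 - alpha_t and guarantees
    # that t=1 (alpha_t=0) masks every token, as the docstring states.
    rand = torch.rand(B, T, device=device, dtype=dtype)
    mask_positions = rand >= alpha_t_expanded          # (B, T) bool
    # True  → masked position
    # False → unmasked (kept as original)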

    # --- build x_t ---
    # masked_fill returns a new tensor, so targets is left untouched.
    x_t = targets.masked_fill(mask_positions, mask_token_id)

    # --- RB loss weights: inverse mask probability on masked positions, 0 elsewhere ---
    # MDLM's continuous-time factor is alpha'_t / (1 - alpha_t). With
    # alpha_t = 1 - t, magnitude is 1 / t. Clamp mask_prob so weights stay
    # finite near t→0, where only rare masked tokens appear.
    mask_prob = (1.0 - alpha_t).clamp(min=_MIN_MASK_PROB)  # (B,)
    weight_per_sample = 1.0 / mask_prob                    # (B,)
    # Broadcast to (B, T) and zero out unmasked positions.
    loss_weights = weight_per_sample[:, None].expand(B, T).to(dtype=dtype)  # (B, T)
    loss_weights = loss_weights * mask_positions.float()

    return x_t, mask_positions, loss_weights


def mdlm_rb_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_positions: torch.Tensor,
    loss_weights: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Rao-Blackwellized negative ELBO.

    Applies the MDLM loss: cross-entropy on masked positions only, weighted
    per-token by loss_weights, averaged over the batch.

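    The formula as implemented here (based on eq. 7-8 of arXiv:2406.07524),
    with b indexing samples and i indexing token positions:
        L_RB = mean_b [ sum_i (weight_{b,i} * CE(logits_{b,i}, target_{b,i}) * mask_{b,i})
                        / max(sum_i mask_{b,i}, 1) ]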

    Args:
        logits        : (B, T, V) raw logits. May be bf16; internally cast to
                        float32 for CE computation.
        targets       : (B, T) int64 true token ids (x_0).
        mask_positions: (B, T) bool. True = masked position.
        loss_weights  : (B, T) float32. Inverse mask probability on masked
                        positions, 0 elsewhere.
        ignore_index  : Passed to F.cross_entropy; positions with this label
                        are excluded from the loss.

    Returns:
        Scalar float32 loss. Returns 0.0 tensor if no positions are masked.
    """
    B, T, V = logits.shape

    # Upcast to float32 so the log-softmax and the weighted sums below run in
    # full precision even when the model emits fp16/bf16 logits. Being explicit
    # avoids silent precision surprises.
    logits_f = logits.float()                          # (B, T, V)

    # Build targets with ignore_index on UNmasked positions so CE only fires
    # where mask_positions is True. We also honour any pre-existing -100 values
    # (e.g. doc-separator masking upstream).
    targets_masked = torch.where(
        mask_positions & (targets != ignore_index),
        targets,
        torch.full_like(targets, ignore_index),
    )

    # Per-token CE; shape (B, T). Positions with ignore_index → 0 from CE.
    per_tok_ce = F.cross_entropy(
        logits_f.reshape(B * T, V),
        targets_masked.reshape(B * T),
        ignore_index=ignore_index,
        reduction="none",
    ).reshape(B, T)                                    # (B, T) float32

    # Apply RB weight. loss_weights already has 0 on unmasked positions.
    weighted = per_tok_ce * loss_weights               # (B, T)

    # Per-sample mean over masked positions, then average over batch.
    mask_f = mask_positions.float()                    # (B, T)
    per_sample_mask_count = mask_f.sum(dim=1).clamp(min=1)   # (B,)
    per_sample_loss = weighted.sum(dim=1) / per_sample_mask_count  # (B,)

    return per_sample_loss.mean()                      # scalar float32


def mdlm_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    mask_token_id: int,
    t: torch.Tensor | None = None,
    alpha_schedule: Literal["linear", "loglinear"] = "loglinear",
    ignore_index: int = -100,
) -> torch.Tensor:
    """Convenience wrapper: forward process + RB-ELBO in one call.

    Suitable for the common case where the caller has full-vocab logits and
    wants a drop-in replacement for a standard masked-LM CE loss.

    Args:
        logits        : (B, T, V) raw logits.
        targets       : (B, T) int64 clean token ids.
        mask_token_id : The MASK token id used to corrupt the input.
        t             : Optional (B,) timestep in (0, 1). Sampled if None.
        alpha_schedule: "loglinear" (default) or "linear".
        ignore_index  : Token id to ignore in the loss (e.g. padding).

    Returns:
        Scalar float32 MDLM RB-ELBO loss.

    Note on sampled-softmax / partial logits:
        If your model only computes logits for a subset of vocab positions
        (e.g. HYDRA's sampled-softmax head), call mdlm_masked_forward_process
        and mdlm_rb_loss separately. mdlm_rb_loss expects full-vocab logits.
    """
    x_t, mask_positions, loss_weights = mdlm_masked_forward_process(
        targets=targets,
        mask_token_id=mask_token_id,
        t=t,
        alpha_schedule=alpha_schedule,
    )
    # NOTE: x_t is discarded here, and the mask is sampled inside this call,
    # so the caller's logits cannot have been computed from this x_t. When the
    # model must see the corrupted input (the usual training setup), call
    # mdlm_masked_forward_process first, feed x_t to the model, then pass the
    # resulting logits to mdlm_rb_loss. See the orchestrator wiring note in
    # training.py.
    return mdlm_rb_loss(
        logits=logits,
        targets=targets,
        mask_positions=mask_positions,
        loss_weights=loss_weights,
        ignore_index=ignore_index,
    )