File size: 12,445 Bytes
e317e25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
"""Tests for hydra/diffusion_loss.py — MDLM Rao-Blackwellized loss.

Paper: Sahoo et al., "Simple and Effective Masked Diffusion Language Models"
       arXiv:2406.07524, NeurIPS 2024.
"""

from __future__ import annotations

import importlib.util
import math
import sys
from pathlib import Path

import pytest
import torch
import torch.nn.functional as F

# ---------------------------------------------------------------------------
# Import diffusion_loss directly from the file to avoid triggering
# hydra/__init__.py, which eagerly imports mamba_ssm (not available in the
# test environment without a GPU build). diffusion_loss.py has zero heavy deps.
# ---------------------------------------------------------------------------
# The path is resolved relative to this test file: two directory levels up,
# then into hydra/.
_MODULE_PATH = Path(__file__).parent.parent / "hydra" / "diffusion_loss.py"
_spec = importlib.util.spec_from_file_location("hydra.diffusion_loss", _MODULE_PATH)
_diffusion_loss_mod = importlib.util.module_from_spec(_spec)  # type: ignore[arg-type]
# Register the module under its dotted name before its body is executed, so
# lookups of sys.modules["hydra.diffusion_loss"] made while it runs succeed.
sys.modules["hydra.diffusion_loss"] = _diffusion_loss_mod
_spec.loader.exec_module(_diffusion_loss_mod)  # type: ignore[union-attr]

# Pull the names under test out of the loaded module so the tests below can
# use them as plain module-level names.
_MAX_WEIGHT = _diffusion_loss_mod._MAX_WEIGHT
_MIN_ALPHA = _diffusion_loss_mod._MIN_ALPHA
mdlm_masked_forward_process = _diffusion_loss_mod.mdlm_masked_forward_process
mdlm_rb_loss = _diffusion_loss_mod.mdlm_rb_loss
mdlm_loss = _diffusion_loss_mod.mdlm_loss

# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------

# Default batch size, sequence length and vocabulary size used by the
# random-tensor helpers in this file.
B, T, V = 4, 32, 512
# Token id passed to the functions under test as the mask token.
MASK_ID = 0


def _random_targets(b=B, t=T, v=V) -> torch.Tensor:
    """Draw a (b, t) tensor of token ids from [1, v).

    Id 0 is never produced, so MASK_ID=0 cannot show up among the targets.
    """
    shape = (b, t)
    return torch.randint(low=1, high=v, size=shape)


def _random_logits(b=B, t=T, v=V) -> torch.Tensor:
    """Draw standard-normal logits of shape (b, t, v)."""
    shape = (b, t, v)
    return torch.randn(*shape)


# ---------------------------------------------------------------------------
# test_forward_process_shape
# ---------------------------------------------------------------------------

def test_forward_process_shape():
    """The forward process returns three (B, T) tensors: int64 ids, bool mask, float32 weights."""
    expected_shape = (B, T)
    x_t, mask, weights = mdlm_masked_forward_process(_random_targets(), MASK_ID)

    # Shapes: every output is laid out like the targets.
    assert x_t.shape == expected_shape, f"x_t shape: {x_t.shape}"
    assert mask.shape == expected_shape, f"mask shape: {mask.shape}"
    assert weights.shape == expected_shape, f"weights shape: {weights.shape}"

    # Dtypes: token ids, boolean mask, single-precision weights.
    assert x_t.dtype == torch.int64, f"x_t dtype: {x_t.dtype}"
    assert mask.dtype == torch.bool, f"mask dtype: {mask.dtype}"
    assert weights.dtype == torch.float32, f"weights dtype: {weights.dtype}"


def test_forward_process_values_consistent():
    """Masked slots hold MASK_ID with positive weight; the rest keep their token and weigh 0."""
    originals = _random_targets()
    x_t, mask, weights = mdlm_masked_forward_process(originals, MASK_ID)
    kept = ~mask

    # Masked positions are overwritten with the mask token id.
    assert (x_t[mask] == MASK_ID).all(), "Masked positions should equal MASK_ID"
    # Unmasked positions pass the original token through.
    assert (x_t[kept] == originals[kept]).all(), "Unmasked positions should equal original"
    # Loss weights are zero off the mask and strictly positive on it.
    assert (weights[kept] == 0.0).all(), "Weights on unmasked positions should be 0"
    assert (weights[mask] > 0.0).all(), "Weights on masked positions should be > 0"


# ---------------------------------------------------------------------------
# test_mask_fraction
# ---------------------------------------------------------------------------

def test_mask_fraction():
    """Mean mask fraction over many samples approximates mean(t) = 0.5.

    The RNG is seeded, so the outcome is deterministic for a given
    implementation of the forward process.
    """
    torch.manual_seed(42)
    n_trials = 2000
    masked_count = 0
    token_count = 0
    for _ in range(n_trials):
        targets = _random_targets(b=4, t=16)
        # Only the mask is needed here; x_t and the weights are discarded.
        _, mask, _ = mdlm_masked_forward_process(targets, MASK_ID)
        # Summing the bool mask gives an exact integer count of masked tokens.
        masked_count += mask.sum().item()
        token_count += mask.numel()

    empirical_frac = masked_count / token_count
    # Expected: E[mask_fraction] = E[1 - alpha_t] = E[t] = 0.5
    # NOTE(review): assuming t is drawn per row from U(0, 1) and tokens are
    # masked independently with probability t (confirm in diffusion_loss.py),
    # tokens in a row are correlated through the shared t, so
    #   std ≈ sqrt((1/12 + 1/(6*16)) / (n_trials * 4)) ≈ 0.0034
    # and the 0.01 tolerance is roughly 3 std.
    assert abs(empirical_frac - 0.5) < 0.01, (
        f"Expected mask fraction ≈ 0.5, got {empirical_frac:.4f}"
    )


def test_mask_fraction_with_fixed_t():
    """With fixed t=0.3, mask fraction ≈ 0.3 (i.e., 1 - alpha_t = 1 - 0.7 = 0.3)."""
    torch.manual_seed(7)
    n_trials = 1000
    t_val = 0.3
    batch = 4  # shared by the targets and the per-row t tensor
    masked_count = 0
    token_count = 0
    for _ in range(n_trials):
        targets = _random_targets(b=batch, t=32)
        # A fresh t tensor is built each trial, one entry per batch row.
        t = torch.full((batch,), t_val)
        # Only the mask is needed here; x_t and the weights are discarded.
        _, mask, _ = mdlm_masked_forward_process(targets, MASK_ID, t=t)
        # Summing the bool mask gives an exact integer count of masked tokens.
        masked_count += mask.sum().item()
        token_count += mask.numel()

    empirical_frac = masked_count / token_count
    assert abs(empirical_frac - t_val) < 0.02, (
        f"Expected mask fraction ≈ {t_val}, got {empirical_frac:.4f}"
    )


# ---------------------------------------------------------------------------
# test_unmasked_loss_zero
# ---------------------------------------------------------------------------

def test_unmasked_loss_zero():
    """With an all-False mask and all-zero weights, mdlm_rb_loss evaluates to 0."""
    targets = _random_targets()
    logits = _random_logits()

    # Nothing is masked: the boolean mask is all False and every weight is 0.
    nothing_masked = torch.zeros((B, T), dtype=torch.bool)
    zero_weights = torch.zeros((B, T))

    result = mdlm_rb_loss(logits, targets, nothing_masked, zero_weights).item()
    assert result == pytest.approx(0.0, abs=1e-6), (
        f"Expected 0.0 when nothing is masked, got {result}"
    )


# ---------------------------------------------------------------------------
# test_loss_scales_with_weight
# ---------------------------------------------------------------------------

def test_loss_scales_with_weight():
    """Doubling loss_weights doubles the loss (linearity)."""
    torch.manual_seed(1234)
    targets = _random_targets()
    logits = _random_logits()

    # Fix a mask (at least some positions must be True).
    mask_positions = torch.rand(B, T) < 0.5
    if not mask_positions.any():
        mask_positions[0, 0] = True
    # Random weights, zeroed wherever the position is not masked.
    base_weights = torch.rand(B, T).float() * mask_positions.float()

    single = mdlm_rb_loss(logits, targets, mask_positions, base_weights).item()
    doubled = mdlm_rb_loss(logits, targets, mask_positions, base_weights * 2.0).item()

    assert doubled == pytest.approx(single * 2.0, rel=1e-5), (
        f"Expected 2x scaling: {single:.6f} * 2 ≠ {doubled:.6f}"
    )


# ---------------------------------------------------------------------------
# test_ce_matches_reference
# ---------------------------------------------------------------------------

def test_ce_matches_reference():
    """On a tiny deterministic case, compare against a hand-computed CE reference.

    All-zero logits give a uniform distribution over V2 classes, so the
    cross-entropy at every position is log(V2) regardless of the target.
    """
    # Nothing below draws random numbers; the seed call is retained so the
    # global RNG state this test leaves behind is unchanged.
    torch.manual_seed(99)
    B2, T2, V2 = 2, 4, 8
    # Targets avoid MASK_ID (0) so they are all "real" tokens.
    targets = torch.tensor([[1, 2, 3, 4], [2, 3, 5, 6]])

    # Fixed logits (all zeros → uniform distribution → CE = log(V))
    logits = torch.zeros(B2, T2, V2)

    # Fixed mask: mask positions (0,0), (0,2), (1,1), (1,3)
    mask_positions = torch.tensor([
        [True, False, True, False],
        [False, True, False, True],
    ])
    # Fixed alpha_t: row 0 → alpha=0.5, row 1 → alpha=0.25
    # Loss weights: row 0 → 1/0.5=2 on masked, row 1 → 1/0.25=4 on masked,
    # and 0 everywhere else (the mask contributes a 0 numerator there).
    alpha = torch.tensor([0.5, 0.25])
    loss_weights = mask_positions.to(torch.float32) / alpha.unsqueeze(1)

    loss = mdlm_rb_loss(logits, targets, mask_positions, loss_weights)

    # Manual reference: CE(uniform over V2=8) = ln(8)
    ce_ref = math.log(V2)

    # The expected value encodes this normalisation: each row's weighted CE
    # sum is divided by that row's masked-token count, then rows are averaged.
    # Row 0: 2 masked positions, each weight=2, CE=ln(8)
    #   weighted_sum = 2 * 2.0 * ln(8)
    #   per_sample   = (2 * 2.0 * ln(8)) / 2 = 2.0 * ln(8)
    row0_loss = 2.0 * ce_ref
    # Row 1: 2 masked positions, each weight=4, CE=ln(8)
    #   weighted_sum = 2 * 4.0 * ln(8)
    #   per_sample   = (2 * 4.0 * ln(8)) / 2 = 4.0 * ln(8)
    row1_loss = 4.0 * ce_ref
    expected = (row0_loss + row1_loss) / 2.0

    assert loss.item() == pytest.approx(expected, rel=1e-4), (
        f"Expected {expected:.6f}, got {loss.item():.6f}"
    )


# ---------------------------------------------------------------------------
# test_autograd_bf16
# ---------------------------------------------------------------------------

def test_autograd_bf16():
    """bf16 logits on CUDA still give an fp32 loss whose backward pass yields finite grads."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    torch.manual_seed(42)
    batch, seq_len, vocab = 2, 16, V
    cuda = torch.device("cuda")

    targets = _random_targets(b=batch, t=seq_len).to(cuda)
    logits_bf16 = torch.randn(
        batch, seq_len, vocab, device=cuda, dtype=torch.bfloat16, requires_grad=True
    )

    # Only the masking step runs inside the bf16 autocast region; the loss is
    # evaluated after leaving it.
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        _x_t, mask, weights = mdlm_masked_forward_process(targets, MASK_ID)

    loss = mdlm_rb_loss(logits_bf16, targets, mask, weights)

    # The loss is expected in single precision even though the logits are bf16.
    assert loss.dtype == torch.float32, f"Expected float32 loss, got {loss.dtype}"

    # Backpropagate and check that a finite gradient reached the logits.
    loss.backward()
    grad = logits_bf16.grad

    assert grad is not None, "No gradient on logits"
    assert torch.isfinite(grad).all(), "Inf/NaN in gradient"


# ---------------------------------------------------------------------------
# test_t_validation
# ---------------------------------------------------------------------------

def test_t_shape_error():
    """A t tensor with B + 1 entries for a batch of B rows is rejected with ValueError."""
    batch = _random_targets()
    wrong_len_t = torch.rand(B + 1)
    with pytest.raises(ValueError, match="shape"):
        mdlm_masked_forward_process(batch, MASK_ID, t=wrong_len_t)


def test_t_range_error():
    """Values of t above 1 are rejected with a ValueError mentioning [0, 1]."""
    batch = _random_targets()
    # rand() lies in [0, 1), so every entry lands in [1.5, 2.5).
    too_large_t = torch.rand(B).add(1.5)
    with pytest.raises(ValueError, match="\\[0, 1\\]"):
        mdlm_masked_forward_process(batch, MASK_ID, t=too_large_t)


# ---------------------------------------------------------------------------
# test_weight_clamping
# ---------------------------------------------------------------------------

def test_weight_clamping():
    """Loss weights stay at or below _MAX_WEIGHT even when t → 1 (alpha_t → 0)."""
    targets = _random_targets()
    # NOTE: under torch's default float32 dtype, 1.0 - 1e-9 is not
    # representable and rounds to exactly 1.0, so this drives t to the very
    # top of its valid range.
    near_one = torch.full((B,), 1.0 - 1e-9)
    _, _, weights = mdlm_masked_forward_process(targets, MASK_ID, t=near_one)
    # Small slack on top of the cap to absorb floating-point rounding.
    ceiling = _MAX_WEIGHT + 1e-6
    assert (weights <= ceiling).all(), (
        f"Weight exceeded _MAX_WEIGHT={_MAX_WEIGHT}; max={weights.max().item()}"
    )


# ---------------------------------------------------------------------------
# test_convenience_wrapper
# ---------------------------------------------------------------------------

def test_mdlm_loss_convenience():
    """The one-call mdlm_loss wrapper yields a finite 0-dim float32 tensor."""
    torch.manual_seed(0)
    # Targets are drawn before logits so the RNG stream is consumed in order.
    targets = _random_targets()
    logits = _random_logits()
    result = mdlm_loss(logits, targets, MASK_ID)
    assert result.ndim == 0, "Expected scalar loss"
    assert result.dtype == torch.float32
    assert torch.isfinite(result), f"Non-finite loss: {result.item()}"


def test_mdlm_loss_no_side_effects():
    """Calling mdlm_loss leaves the caller's targets and logits tensors unchanged."""
    targets = _random_targets()
    logits = _random_logits()
    # Snapshot both inputs first so any in-place write shows up as a mismatch.
    targets_before = targets.clone()
    logits_before = logits.clone()

    mdlm_loss(logits, targets, MASK_ID)

    assert (targets == targets_before).all(), "targets was mutated"
    assert (logits == logits_before).all(), "logits was mutated"


# ---------------------------------------------------------------------------
# test_alpha_schedule_unknown
# ---------------------------------------------------------------------------

def test_alpha_schedule_unknown():
    """Passing alpha_schedule="cosine" is rejected with ValueError."""
    batch = _random_targets()
    unsupported = "cosine"
    with pytest.raises(ValueError, match="Unknown alpha_schedule"):
        mdlm_masked_forward_process(batch, MASK_ID, alpha_schedule=unsupported)  # type: ignore