"""Tests for hydra/diffusion_loss.py — MDLM Rao-Blackwellized loss. Paper: Sahoo et al., "Simple and Effective Masked Diffusion Language Models" arXiv:2406.07524, NeurIPS 2024. """ from __future__ import annotations import importlib.util import math import sys from pathlib import Path import pytest import torch import torch.nn.functional as F # --------------------------------------------------------------------------- # Import diffusion_loss directly from the file to avoid triggering # hydra/__init__.py, which eagerly imports mamba_ssm (not available in the # test environment without a GPU build). diffusion_loss.py has zero heavy deps. # --------------------------------------------------------------------------- _MODULE_PATH = Path(__file__).parent.parent / "hydra" / "diffusion_loss.py" _spec = importlib.util.spec_from_file_location("hydra.diffusion_loss", _MODULE_PATH) _diffusion_loss_mod = importlib.util.module_from_spec(_spec) # type: ignore[arg-type] sys.modules["hydra.diffusion_loss"] = _diffusion_loss_mod _spec.loader.exec_module(_diffusion_loss_mod) # type: ignore[union-attr] _MAX_WEIGHT = _diffusion_loss_mod._MAX_WEIGHT _MIN_ALPHA = _diffusion_loss_mod._MIN_ALPHA mdlm_masked_forward_process = _diffusion_loss_mod.mdlm_masked_forward_process mdlm_rb_loss = _diffusion_loss_mod.mdlm_rb_loss mdlm_loss = _diffusion_loss_mod.mdlm_loss # --------------------------------------------------------------------------- # Fixtures / helpers # --------------------------------------------------------------------------- B, T, V = 4, 32, 512 MASK_ID = 0 def _random_targets(b=B, t=T, v=V) -> torch.Tensor: """Random token ids in [1, V) so MASK_ID=0 is unambiguously special.""" return torch.randint(1, v, (b, t)) def _random_logits(b=B, t=T, v=V) -> torch.Tensor: return torch.randn(b, t, v) # --------------------------------------------------------------------------- # test_forward_process_shape # --------------------------------------------------------------------------- def test_forward_process_shape(): """x_t, mask_positions, loss_weights all have shape (B, T) with correct dtypes.""" targets = _random_targets() x_t, mask, weights = mdlm_masked_forward_process(targets, MASK_ID) assert x_t.shape == (B, T), f"x_t shape: {x_t.shape}" assert mask.shape == (B, T), f"mask shape: {mask.shape}" assert weights.shape == (B, T), f"weights shape: {weights.shape}" assert x_t.dtype == torch.int64, f"x_t dtype: {x_t.dtype}" assert mask.dtype == torch.bool, f"mask dtype: {mask.dtype}" assert weights.dtype == torch.float32, f"weights dtype: {weights.dtype}" def test_forward_process_values_consistent(): """Masked positions get mask_token_id; unmasked positions keep original.""" targets = _random_targets() x_t, mask, weights = mdlm_masked_forward_process(targets, MASK_ID) # Masked → mask token id assert (x_t[mask] == MASK_ID).all(), "Masked positions should equal MASK_ID" # Unmasked → original token assert (x_t[~mask] == targets[~mask]).all(), "Unmasked positions should equal original" # Weights non-zero only on masked positions assert (weights[~mask] == 0.0).all(), "Weights on unmasked positions should be 0" assert (weights[mask] > 0.0).all(), "Weights on masked positions should be > 0" # --------------------------------------------------------------------------- # test_mask_fraction # --------------------------------------------------------------------------- def test_mask_fraction(): """Mean mask fraction over many samples approximates mean(t) = 0.5.""" torch.manual_seed(42) n_trials = 2000 total_mask = 0 total_tokens = 0 for _ in range(n_trials): targets = _random_targets(b=4, t=16) x_t, mask, _ = mdlm_masked_forward_process(targets, MASK_ID) total_mask += mask.float().sum().item() total_tokens += mask.numel() empirical_frac = total_mask / total_tokens # Expected: E[mask_fraction] = E[1 - alpha_t] = E[t] = 0.5 # With n_trials=2000 and B*T=64, std ≈ 0.5/sqrt(n_trials*B*T) ≈ 0.0014 # Tolerance = 4 std ≈ 0.006 assert abs(empirical_frac - 0.5) < 0.01, ( f"Expected mask fraction ≈ 0.5, got {empirical_frac:.4f}" ) def test_mask_fraction_with_fixed_t(): """With fixed t=0.3, mask fraction ≈ 0.3 (i.e., 1 - alpha_t = 1 - 0.7 = 0.3).""" torch.manual_seed(7) n_trials = 1000 t_val = 0.3 total_mask = 0 total_tokens = 0 for _ in range(n_trials): targets = _random_targets(b=4, t=32) t = torch.full((4,), t_val) x_t, mask, _ = mdlm_masked_forward_process(targets, MASK_ID, t=t) total_mask += mask.float().sum().item() total_tokens += mask.numel() empirical_frac = total_mask / total_tokens assert abs(empirical_frac - t_val) < 0.02, ( f"Expected mask fraction ≈ {t_val}, got {empirical_frac:.4f}" ) # --------------------------------------------------------------------------- # test_unmasked_loss_zero # --------------------------------------------------------------------------- def test_unmasked_loss_zero(): """When no positions are masked, rb_loss returns exactly 0.""" targets = _random_targets() logits = _random_logits() # Force mask_positions = all False and weights = 0 mask_positions = torch.zeros(B, T, dtype=torch.bool) loss_weights = torch.zeros(B, T) loss = mdlm_rb_loss(logits, targets, mask_positions, loss_weights) assert loss.item() == pytest.approx(0.0, abs=1e-6), ( f"Expected 0.0 when nothing is masked, got {loss.item()}" ) # --------------------------------------------------------------------------- # test_loss_scales_with_weight # --------------------------------------------------------------------------- def test_loss_scales_with_weight(): """Doubling loss_weights doubles the loss (linearity).""" torch.manual_seed(1234) targets = _random_targets() logits = _random_logits() # Fix a mask (at least some positions must be True). mask_positions = torch.rand(B, T) < 0.5 if not mask_positions.any(): mask_positions[0, 0] = True base_weights = torch.rand(B, T).float() * mask_positions.float() loss1 = mdlm_rb_loss(logits, targets, mask_positions, base_weights) loss2 = mdlm_rb_loss(logits, targets, mask_positions, base_weights * 2.0) assert loss2.item() == pytest.approx(loss1.item() * 2.0, rel=1e-5), ( f"Expected 2x scaling: {loss1.item():.6f} * 2 ≠ {loss2.item():.6f}" ) # --------------------------------------------------------------------------- # test_ce_matches_reference # --------------------------------------------------------------------------- def test_ce_matches_reference(): """On a tiny deterministic case, compare against manual numpy CE.""" torch.manual_seed(99) B2, T2, V2 = 2, 4, 8 targets = torch.tensor([[1, 2, 3, 1], [2, 3, 0, 1]]) # NOTE: token 0 = MASK_ID # Actually use targets without MASK_ID so they are all "real" tokens targets = torch.tensor([[1, 2, 3, 4], [2, 3, 5, 6]]) # Fixed logits (all zeros → uniform distribution → CE = log(V)) logits = torch.zeros(B2, T2, V2) # Fixed mask: mask positions (0,0), (0,2), (1,1), (1,3) mask_positions = torch.tensor([ [True, False, True, False], [False, True, False, True], ]) # Fixed alpha_t: row 0 → alpha=0.5, row 1 → alpha=0.25 # Loss weights: row 0 → 1/0.5=2 on masked, row 1 → 1/0.25=4 on masked alpha = torch.tensor([0.5, 0.25]) loss_weights = torch.zeros(B2, T2) for i in range(B2): for j in range(T2): if mask_positions[i, j]: loss_weights[i, j] = 1.0 / alpha[i].item() loss = mdlm_rb_loss(logits, targets, mask_positions, loss_weights) # Manual reference via numpy: # CE(uniform over V2=8) = log(8) = ln(8) ce_ref = math.log(V2) # Row 0: 2 masked positions, each weight=2, CE=ln(8) # weighted_sum = 2 * 2.0 * ln(8) # per_sample = (2 * 2.0 * ln(8)) / 2 = 2.0 * ln(8) row0_loss = 2.0 * ce_ref # Row 1: 2 masked positions, each weight=4, CE=ln(8) # weighted_sum = 2 * 4.0 * ln(8) # per_sample = (2 * 4.0 * ln(8)) / 2 = 4.0 * ln(8) row1_loss = 4.0 * ce_ref expected = (row0_loss + row1_loss) / 2.0 assert loss.item() == pytest.approx(expected, rel=1e-4), ( f"Expected {expected:.6f}, got {loss.item():.6f}" ) # --------------------------------------------------------------------------- # test_autograd_bf16 # --------------------------------------------------------------------------- def test_autograd_bf16(): """Loss is fp32 and backward produces finite grads even with bf16 logits.""" if not torch.cuda.is_available(): pytest.skip("CUDA not available") torch.manual_seed(42) B3, T3, V3 = 2, 16, V device = torch.device("cuda") targets = _random_targets(b=B3, t=T3).to(device) logits_bf16 = torch.randn(B3, T3, V3, device=device, dtype=torch.bfloat16, requires_grad=True) with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): x_t, mask, weights = mdlm_masked_forward_process(targets, MASK_ID) loss = mdlm_rb_loss(logits_bf16, targets, mask, weights) # Loss must be float32 assert loss.dtype == torch.float32, f"Expected float32 loss, got {loss.dtype}" # Backward must succeed and produce finite grads loss.backward() assert logits_bf16.grad is not None, "No gradient on logits" assert torch.isfinite(logits_bf16.grad).all(), "Inf/NaN in gradient" # --------------------------------------------------------------------------- # test_t_validation # --------------------------------------------------------------------------- def test_t_shape_error(): """Wrong t shape raises ValueError.""" targets = _random_targets() bad_t = torch.rand(B + 1) with pytest.raises(ValueError, match="shape"): mdlm_masked_forward_process(targets, MASK_ID, t=bad_t) def test_t_range_error(): """t outside [0, 1] raises ValueError.""" targets = _random_targets() bad_t = torch.rand(B) + 1.5 # all > 1 with pytest.raises(ValueError, match="\\[0, 1\\]"): mdlm_masked_forward_process(targets, MASK_ID, t=bad_t) # --------------------------------------------------------------------------- # test_weight_clamping # --------------------------------------------------------------------------- def test_weight_clamping(): """Loss weights capped at _MAX_WEIGHT even when t → 1 (alpha_t → 0).""" targets = _random_targets() # t very close to 1 → alpha_t very close to 0 t = torch.full((B,), 1.0 - 1e-9) x_t, mask, weights = mdlm_masked_forward_process(targets, MASK_ID, t=t) assert (weights <= _MAX_WEIGHT + 1e-6).all(), ( f"Weight exceeded _MAX_WEIGHT={_MAX_WEIGHT}; max={weights.max().item()}" ) # --------------------------------------------------------------------------- # test_convenience_wrapper # --------------------------------------------------------------------------- def test_mdlm_loss_convenience(): """mdlm_loss end-to-end returns a scalar float32 loss.""" torch.manual_seed(0) targets = _random_targets() logits = _random_logits() loss = mdlm_loss(logits, targets, MASK_ID) assert loss.ndim == 0, "Expected scalar loss" assert loss.dtype == torch.float32 assert torch.isfinite(loss), f"Non-finite loss: {loss.item()}" def test_mdlm_loss_no_side_effects(): """mdlm_loss does not mutate targets or logits tensors.""" targets = _random_targets() logits = _random_logits() targets_copy = targets.clone() logits_copy = logits.clone() _ = mdlm_loss(logits, targets, MASK_ID) assert (targets == targets_copy).all(), "targets was mutated" assert (logits == logits_copy).all(), "logits was mutated" # --------------------------------------------------------------------------- # test_alpha_schedule_unknown # --------------------------------------------------------------------------- def test_alpha_schedule_unknown(): """Unknown alpha_schedule raises ValueError.""" targets = _random_targets() with pytest.raises(ValueError, match="Unknown alpha_schedule"): mdlm_masked_forward_process(targets, MASK_ID, alpha_schedule="cosine") # type: ignore