Spaces:
Runtime error
Runtime error
"""Tests for hydra/diffusion_loss.py — MDLM Rao-Blackwellized loss.

Paper: Sahoo et al., "Simple and Effective Masked Diffusion Language Models",
arXiv:2406.07524, NeurIPS 2024.
"""
| from __future__ import annotations | |
| import importlib.util | |
| import math | |
| import sys | |
| from pathlib import Path | |
| import pytest | |
| import torch | |
| import torch.nn.functional as F | |
# ---------------------------------------------------------------------------
# Import diffusion_loss directly from its file so that hydra/__init__.py is
# never executed: that __init__ eagerly imports mamba_ssm, which is not
# available in the test environment without a GPU build. diffusion_loss.py
# itself has no heavy dependencies.
# ---------------------------------------------------------------------------
_MODULE_PATH = Path(__file__).parent.parent / "hydra" / "diffusion_loss.py"
_spec = importlib.util.spec_from_file_location("hydra.diffusion_loss", _MODULE_PATH)
if _spec is None or _spec.loader is None:
    # spec_from_file_location returns None when no loader matches the file's
    # suffix; fail here with a clear message rather than with an
    # AttributeError further down.
    raise ImportError(f"Cannot create an import spec for {_MODULE_PATH}")
_diffusion_loss_mod = importlib.util.module_from_spec(_spec)
# Register under the dotted name before executing (as in the importlib
# "import a source file directly" recipe) so that anything looking the module
# up in sys.modules while it executes can find it.
sys.modules["hydra.diffusion_loss"] = _diffusion_loss_mod
try:
    _spec.loader.exec_module(_diffusion_loss_mod)
except BaseException:
    # Do not leave a half-initialised module registered in sys.modules.
    sys.modules.pop("hydra.diffusion_loss", None)
    raise

# Names under test, re-exported at module level for the tests below.
_MAX_WEIGHT = _diffusion_loss_mod._MAX_WEIGHT
_MIN_ALPHA = _diffusion_loss_mod._MIN_ALPHA
mdlm_masked_forward_process = _diffusion_loss_mod.mdlm_masked_forward_process
mdlm_rb_loss = _diffusion_loss_mod.mdlm_rb_loss
mdlm_loss = _diffusion_loss_mod.mdlm_loss
| # --------------------------------------------------------------------------- | |
| # Fixtures / helpers | |
| # --------------------------------------------------------------------------- | |
| B, T, V = 4, 32, 512 | |
| MASK_ID = 0 | |
| def _random_targets(b=B, t=T, v=V) -> torch.Tensor: | |
| """Random token ids in [1, V) so MASK_ID=0 is unambiguously special.""" | |
| return torch.randint(1, v, (b, t)) | |
| def _random_logits(b=B, t=T, v=V) -> torch.Tensor: | |
| return torch.randn(b, t, v) | |
# ---------------------------------------------------------------------------
# test_forward_process_shape
# ---------------------------------------------------------------------------
def test_forward_process_shape():
    """Each forward-process output is (B, T): int64 tokens, bool mask, fp32 weights."""
    x_t, mask, weights = mdlm_masked_forward_process(_random_targets(), MASK_ID)
    outputs = (
        ("x_t", x_t, torch.int64),
        ("mask", mask, torch.bool),
        ("weights", weights, torch.float32),
    )
    # All shapes are checked before any dtype.
    for label, tensor, _ in outputs:
        assert tensor.shape == (B, T), f"{label} shape: {tensor.shape}"
    for label, tensor, want_dtype in outputs:
        assert tensor.dtype == want_dtype, f"{label} dtype: {tensor.dtype}"
def test_forward_process_values_consistent():
    """x_t holds MASK_ID where masked and the clean token elsewhere; only masked slots carry weight."""
    targets = _random_targets()
    x_t, mask, weights = mdlm_masked_forward_process(targets, MASK_ID)
    kept = ~mask

    # Masked slots are overwritten with the mask token id.
    assert (x_t[mask] == MASK_ID).all(), "Masked positions should equal MASK_ID"
    # Kept slots pass the clean token through unchanged.
    assert (x_t[kept] == targets[kept]).all(), "Unmasked positions should equal original"
    # Loss weight is zero wherever nothing was masked ...
    assert (weights[kept] == 0.0).all(), "Weights on unmasked positions should be 0"
    # ... and strictly positive wherever something was.
    assert (weights[mask] > 0.0).all(), "Weights on masked positions should be > 0"
# ---------------------------------------------------------------------------
# test_mask_fraction
# ---------------------------------------------------------------------------
def test_mask_fraction():
    """Mean mask fraction over many samples approximates E[t] = 0.5.

    Assuming the default schedule is linear (alpha_t = 1 - t) and t is drawn
    uniformly on [0, 1], a token is masked with probability 1 - alpha_t = t,
    so the expected mask fraction is E[t] = 0.5.
    """
    torch.manual_seed(42)
    n_trials = 2000
    total_mask = 0.0
    total_tokens = 0
    for _ in range(n_trials):
        targets = _random_targets(b=4, t=16)
        _, mask, _ = mdlm_masked_forward_process(targets, MASK_ID)
        total_mask += mask.float().sum().item()
        total_tokens += mask.numel()
    empirical_frac = total_mask / total_tokens
    # Error budget: t is one value per sequence, so the 16 tokens of a row are
    # correlated through their shared t and are NOT independent draws. If t is
    # i.i.d. U[0, 1] per row, a row's mask fraction has variance
    #   E[t(1 - t)] / 16 + Var(t) = (1/6) / 16 + 1/12 ≈ 0.094,
    # and averaging 2000 * 4 = 8000 rows gives std ≈ sqrt(0.094 / 8000) ≈ 0.0034
    # (smaller if t is sampled with stratification). The 0.01 tolerance is thus
    # roughly 3 std; the fixed seed keeps the outcome deterministic.
    assert abs(empirical_frac - 0.5) < 0.01, (
        f"Expected mask fraction ≈ 0.5, got {empirical_frac:.4f}"
    )
def test_mask_fraction_with_fixed_t():
    """With t fixed at 0.3, the mask fraction is ≈ 1 - alpha_t = 1 - 0.7 = 0.3."""
    torch.manual_seed(7)
    n_trials = 1000
    t_val = 0.3
    total_mask = 0.0
    total_tokens = 0
    for _ in range(n_trials):
        targets = _random_targets(b=4, t=32)
        t = torch.full((4,), t_val)
        _, mask, _ = mdlm_masked_forward_process(targets, MASK_ID, t=t)
        total_mask += mask.float().sum().item()
        total_tokens += mask.numel()
    empirical_frac = total_mask / total_tokens
    # Assuming tokens are masked independently given t, the std over
    # 1000 * 128 tokens is ≈ sqrt(0.3 * 0.7 / 128000) ≈ 0.0013, so 0.02 is a
    # deliberately loose bound.
    assert abs(empirical_frac - t_val) < 0.02, (
        f"Expected mask fraction ≈ {t_val}, got {empirical_frac:.4f}"
    )
# ---------------------------------------------------------------------------
# test_unmasked_loss_zero
# ---------------------------------------------------------------------------
def test_unmasked_loss_zero():
    """rb_loss is 0 (within 1e-6) when no position is masked and all weights are 0."""
    targets, logits = _random_targets(), _random_logits()
    # Nothing masked, nothing weighted.
    no_mask = torch.zeros((B, T), dtype=torch.bool)
    zero_weights = torch.zeros((B, T))
    value = mdlm_rb_loss(logits, targets, no_mask, zero_weights).item()
    assert value == pytest.approx(0.0, abs=1e-6), (
        f"Expected 0.0 when nothing is masked, got {value}"
    )
# ---------------------------------------------------------------------------
# test_loss_scales_with_weight
# ---------------------------------------------------------------------------
def test_loss_scales_with_weight():
    """Doubling loss_weights doubles the loss (rb_loss is linear in the weights)."""
    torch.manual_seed(1234)
    targets = _random_targets()
    logits = _random_logits()
    # Fix a random mask, forcing at least one masked position so that neither
    # loss is trivially zero.
    mask_positions = torch.rand(B, T) < 0.5
    if not mask_positions.any():
        mask_positions[0, 0] = True
    # Weights are non-zero only on masked positions.
    base_weights = torch.rand(B, T).float() * mask_positions.float()
    loss1 = mdlm_rb_loss(logits, targets, mask_positions, base_weights)
    loss2 = mdlm_rb_loss(logits, targets, mask_positions, base_weights * 2.0)
    assert loss2.item() == pytest.approx(loss1.item() * 2.0, rel=1e-5), (
        f"Expected 2x scaling: {loss1.item():.6f} * 2 ≈ {loss2.item():.6f}"
    )
# ---------------------------------------------------------------------------
# test_ce_matches_reference
# ---------------------------------------------------------------------------
def test_ce_matches_reference():
    """On a tiny deterministic case, compare against a hand-computed CE reference."""
    B2, T2, V2 = 2, 4, 8
    # Real tokens only (none equal MASK_ID = 0), so every target is an ordinary class.
    targets = torch.tensor([[1, 2, 3, 4], [2, 3, 5, 6]])
    # All-zero logits give a uniform distribution, so CE = log(V2) at every position.
    logits = torch.zeros(B2, T2, V2)
    # Masked positions (0, 0), (0, 2), (1, 1), (1, 3): two per row.
    mask_positions = torch.tensor([
        [True, False, True, False],
        [False, True, False, True],
    ])
    # Per-row alpha_t: row 0 -> 0.5, row 1 -> 0.25. The weight is 1/alpha_t on
    # masked positions (2 for row 0, 4 for row 1) and 0 elsewhere.
    alpha = torch.tensor([0.5, 0.25])
    loss_weights = mask_positions.float() / alpha.unsqueeze(1)
    loss = mdlm_rb_loss(logits, targets, mask_positions, loss_weights)

    ce_ref = math.log(V2)
    # Expected reduction: per sample, sum weighted CE over masked positions and
    # divide by the number of masked positions; then average over the batch.
    #   row 0: (2 * 2.0 * ce_ref) / 2 = 2.0 * ce_ref
    #   row 1: (2 * 4.0 * ce_ref) / 2 = 4.0 * ce_ref
    # NOTE: both rows have the same mask count, so this case cannot distinguish
    # per-sample normalisation from normalising by the total masked count.
    row0_loss = 2.0 * ce_ref
    row1_loss = 4.0 * ce_ref
    expected = (row0_loss + row1_loss) / 2.0
    assert loss.item() == pytest.approx(expected, rel=1e-4), (
        f"Expected {expected:.6f}, got {loss.item():.6f}"
    )
# ---------------------------------------------------------------------------
# test_autograd_bf16
# ---------------------------------------------------------------------------
def test_autograd_bf16():
    """With bf16 logits under autocast, the loss is fp32 and backward yields finite grads."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    torch.manual_seed(42)
    cuda = torch.device("cuda")
    batch, seq = 2, 16
    targets = _random_targets(b=batch, t=seq).to(cuda)
    logits_bf16 = torch.randn(
        batch, seq, V, device=cuda, dtype=torch.bfloat16, requires_grad=True,
    )
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        _, mask, weights = mdlm_masked_forward_process(targets, MASK_ID)
        loss = mdlm_rb_loss(logits_bf16, targets, mask, weights)

    # The reduced loss must come back in full precision.
    assert loss.dtype == torch.float32, f"Expected float32 loss, got {loss.dtype}"

    # Gradients must reach the bf16 leaf and contain no Inf/NaN.
    loss.backward()
    grad = logits_bf16.grad
    assert grad is not None, "No gradient on logits"
    assert torch.isfinite(grad).all(), "Inf/NaN in gradient"
# ---------------------------------------------------------------------------
# test_t_validation
# ---------------------------------------------------------------------------
def test_t_shape_error():
    """A t vector whose length differs from the batch size raises ValueError."""
    targets = _random_targets()
    t_too_long = torch.rand(B + 1)  # one entry more than there are rows
    with pytest.raises(ValueError, match="shape"):
        mdlm_masked_forward_process(targets, MASK_ID, t=t_too_long)
def test_t_range_error():
    """A t vector with values outside [0, 1] raises ValueError."""
    targets = _random_targets()
    out_of_range_t = 1.5 + torch.rand(B)  # every entry lies in [1.5, 2.5)
    with pytest.raises(ValueError, match=r"\[0, 1\]"):
        mdlm_masked_forward_process(targets, MASK_ID, t=out_of_range_t)
# ---------------------------------------------------------------------------
# test_weight_clamping
# ---------------------------------------------------------------------------
def test_weight_clamping():
    """Loss weights stay capped at _MAX_WEIGHT as t → 1 (alpha_t → 0)."""
    targets = _random_targets()
    # NOTE: torch.full with a Python float builds a tensor of the default dtype
    # (float32), whose spacing just below 1.0 is ~6e-8. A value such as
    # 1.0 - 1e-9 therefore rounds to exactly 1.0. Check a value that really is
    # just below 1 in float32, as well as the endpoint t = 1 itself.
    for t_val in (1.0 - 1e-6, 1.0):
        t = torch.full((B,), t_val)
        _, mask, weights = mdlm_masked_forward_process(targets, MASK_ID, t=t)
        # Near t = 1 almost every token is masked; without any masked token the
        # cap check below would pass vacuously.
        assert mask.any(), f"t={t_val}: expected masked positions near t = 1"
        assert (weights <= _MAX_WEIGHT + 1e-6).all(), (
            f"Weight exceeded _MAX_WEIGHT={_MAX_WEIGHT}; max={weights.max().item()}"
        )
# ---------------------------------------------------------------------------
# test_convenience_wrapper
# ---------------------------------------------------------------------------
def test_mdlm_loss_convenience():
    """End to end, mdlm_loss yields a finite 0-d float32 tensor."""
    torch.manual_seed(0)
    targets = _random_targets()
    logits = _random_logits()
    result = mdlm_loss(logits, targets, MASK_ID)
    assert result.ndim == 0, "Expected scalar loss"
    assert result.dtype == torch.float32
    assert torch.isfinite(result), f"Non-finite loss: {result.item()}"
def test_mdlm_loss_no_side_effects():
    """Calling mdlm_loss leaves its targets and logits arguments untouched."""
    targets = _random_targets()
    logits = _random_logits()
    # Snapshot the inputs before the call, compare afterwards.
    before = {"targets": targets.clone(), "logits": logits.clone()}
    mdlm_loss(logits, targets, MASK_ID)
    assert (targets == before["targets"]).all(), "targets was mutated"
    assert (logits == before["logits"]).all(), "logits was mutated"
# ---------------------------------------------------------------------------
# test_alpha_schedule_unknown
# ---------------------------------------------------------------------------
def test_alpha_schedule_unknown():
    """An alpha_schedule name the module does not support raises ValueError."""
    targets = _random_targets()
    unsupported = "cosine"
    with pytest.raises(ValueError, match="Unknown alpha_schedule"):
        mdlm_masked_forward_process(targets, MASK_ID, alpha_schedule=unsupported)  # type: ignore