# feather-runtime/overlay/tests/test_diffusion_loss.py
# Last commit e317e25 (verified) by Jackoatmon: "Update Feather h200 training runtime image"
"""Tests for hydra/diffusion_loss.py — MDLM Rao-Blackwellized loss.

Paper: Sahoo et al., "Simple and Effective Masked Diffusion Language Models",
arXiv:2406.07524, NeurIPS 2024.
"""
from __future__ import annotations
import importlib.util
import math
import sys
from pathlib import Path
import pytest
import torch
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Import diffusion_loss directly from the file to avoid triggering
# hydra/__init__.py, which eagerly imports mamba_ssm (not available in the
# test environment without a GPU build). diffusion_loss.py has zero heavy deps.
# ---------------------------------------------------------------------------
# Module under test, resolved relative to this file: <tests>/../hydra/diffusion_loss.py
_MODULE_PATH = Path(__file__).parent.parent / "hydra" / "diffusion_loss.py"
# Build a spec under the dotted name "hydra.diffusion_loss" straight from the file
# path, so the parent "hydra" package (and its __init__.py) is never imported.
_spec = importlib.util.spec_from_file_location("hydra.diffusion_loss", _MODULE_PATH)
# typeshed types the spec (and spec.loader) as Optional; the ignores below
# silence that, since a ".py" path is expected to always yield a loader.
_diffusion_loss_mod = importlib.util.module_from_spec(_spec) # type: ignore[arg-type]
# Register the module object BEFORE executing it (the order used by the importlib
# "import a source file directly" recipe), so anything that looks up
# sys.modules["hydra.diffusion_loss"] while or after the body runs finds this object.
sys.modules["hydra.diffusion_loss"] = _diffusion_loss_mod
# Execute the module body; presumably raises (e.g. FileNotFoundError) if the
# file is missing — TODO confirm desired failure mode.
_spec.loader.exec_module(_diffusion_loss_mod) # type: ignore[union-attr]
# Re-export the names exercised by the tests below at this module's top level.
# NOTE(review): _MIN_ALPHA is not referenced by any test visible here.
_MAX_WEIGHT = _diffusion_loss_mod._MAX_WEIGHT
_MIN_ALPHA = _diffusion_loss_mod._MIN_ALPHA
mdlm_masked_forward_process = _diffusion_loss_mod.mdlm_masked_forward_process
mdlm_rb_loss = _diffusion_loss_mod.mdlm_rb_loss
mdlm_loss = _diffusion_loss_mod.mdlm_loss
# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------
B = 4  # default batch size used by the helpers and most tests
T = 32  # default sequence length
V = 512  # default vocabulary size (last dim of the logits)
MASK_ID = 0  # id passed as the mask token; real targets are drawn from [1, V)
def _random_targets(b=B, t=T, v=V) -> torch.Tensor:
    """Sample a (b, t) int64 tensor of token ids uniformly from [1, v).

    Id 0 is never drawn, so MASK_ID=0 can never collide with a real token.
    """
    return torch.randint(low=1, high=v, size=(b, t))
def _random_logits(b=B, t=T, v=V) -> torch.Tensor:
    """Sample a (b, t, v) tensor of standard-normal logits (default dtype)."""
    shape = (b, t, v)
    return torch.randn(shape)
# ---------------------------------------------------------------------------
# test_forward_process_shape
# ---------------------------------------------------------------------------
def test_forward_process_shape():
    """Forward process yields three (B, T) tensors: int64 ids, bool mask, float32 weights."""
    noisy_ids, is_masked, per_token_w = mdlm_masked_forward_process(
        _random_targets(), MASK_ID
    )
    want = (B, T)
    # Shapes first, then dtypes.
    assert noisy_ids.shape == want, f"x_t shape: {noisy_ids.shape}"
    assert is_masked.shape == want, f"mask shape: {is_masked.shape}"
    assert per_token_w.shape == want, f"weights shape: {per_token_w.shape}"
    assert noisy_ids.dtype == torch.int64, f"x_t dtype: {noisy_ids.dtype}"
    assert is_masked.dtype == torch.bool, f"mask dtype: {is_masked.dtype}"
    assert per_token_w.dtype == torch.float32, f"weights dtype: {per_token_w.dtype}"
def test_forward_process_values_consistent():
    """Masked positions get mask_token_id; unmasked positions keep the original."""
    torch.manual_seed(0)  # reproducible mask draw
    targets = _random_targets()
    x_t, mask, weights = mdlm_masked_forward_process(targets, MASK_ID)
    # Boolean indexing with an empty selection makes `.all()` vacuously True,
    # so make sure both the masked and unmasked branches are actually exercised.
    assert mask.any(), "No positions were masked; test would be vacuous"
    assert (~mask).any(), "All positions were masked; test would be vacuous"
    # Masked → mask token id
    assert (x_t[mask] == MASK_ID).all(), "Masked positions should equal MASK_ID"
    # Unmasked → original token
    assert (x_t[~mask] == targets[~mask]).all(), "Unmasked positions should equal original"
    # Weights non-zero only on masked positions
    assert (weights[~mask] == 0.0).all(), "Weights on unmasked positions should be 0"
    assert (weights[mask] > 0.0).all(), "Weights on masked positions should be > 0"
# ---------------------------------------------------------------------------
# test_mask_fraction
# ---------------------------------------------------------------------------
def test_mask_fraction():
    """Mean mask fraction over many samples approximates E[t] = 0.5."""
    torch.manual_seed(42)
    n_trials = 2000
    rows, cols = 4, 16
    total_mask = 0.0
    total_tokens = 0
    for _ in range(n_trials):
        targets = _random_targets(b=rows, t=cols)
        _, mask, _ = mdlm_masked_forward_process(targets, MASK_ID)
        total_mask += mask.sum().item()
        total_tokens += mask.numel()
    empirical_frac = total_mask / total_tokens
    # Expected: E[mask_fraction] = E[1 - alpha_t] = E[t] = 0.5.
    # t is drawn once per row (shape (B,)), so tokens within a row are NOT
    # independent and the n_trials * rows = 8000 row-level draws dominate the
    # variance. Assuming t ~ U[0, 1] i.i.d. per row (TODO confirm against
    # diffusion_loss.py; a low-discrepancy sampler would only shrink this):
    #   per-row std  ≈ sqrt(Var(t) + E[t(1-t)]/cols) = sqrt(1/12 + 1/96) ≈ 0.31
    #   std of mean  ≈ 0.31 / sqrt(8000) ≈ 0.0034
    # so the 0.01 tolerance is ≈ 3 std. The fixed seed makes the run deterministic.
    assert abs(empirical_frac - 0.5) < 0.01, (
        f"Expected mask fraction ≈ 0.5, got {empirical_frac:.4f}"
    )
def test_mask_fraction_with_fixed_t():
    """With fixed t=0.3, mask fraction ≈ 0.3 (i.e., 1 - alpha_t = 1 - 0.7 = 0.3)."""
    torch.manual_seed(7)
    n_trials = 1000
    t_val = 0.3
    total_mask = 0.0
    total_tokens = 0
    for _ in range(n_trials):
        targets = _random_targets(b=4, t=32)
        # One t per row; derive the length from targets so the two can't drift apart.
        t = torch.full((targets.shape[0],), t_val)
        _, mask, _ = mdlm_masked_forward_process(targets, MASK_ID, t=t)
        total_mask += mask.sum().item()
        total_tokens += mask.numel()
    empirical_frac = total_mask / total_tokens
    # With t fixed, tokens are presumably masked i.i.d. with prob 0.3, giving
    # std ≈ sqrt(0.3 * 0.7 / 128000) ≈ 0.0013; 0.02 is a deliberately loose bound.
    assert abs(empirical_frac - t_val) < 0.02, (
        f"Expected mask fraction ≈ {t_val}, got {empirical_frac:.4f}"
    )
# ---------------------------------------------------------------------------
# test_unmasked_loss_zero
# ---------------------------------------------------------------------------
def test_unmasked_loss_zero():
    """mdlm_rb_loss is exactly zero when the mask selects no positions."""
    tgt = _random_targets()
    lg = _random_logits()
    # Nothing masked, and every per-position weight is zero.
    no_mask = torch.zeros((B, T), dtype=torch.bool)
    zero_w = torch.zeros((B, T))
    result = mdlm_rb_loss(lg, tgt, no_mask, zero_w).item()
    assert result == pytest.approx(0.0, abs=1e-6), (
        f"Expected 0.0 when nothing is masked, got {result}"
    )
# ---------------------------------------------------------------------------
# test_loss_scales_with_weight
# ---------------------------------------------------------------------------
def test_loss_scales_with_weight():
    """Doubling loss_weights doubles the loss (linearity)."""
    torch.manual_seed(1234)
    targets = _random_targets()
    logits = _random_logits()
    # Fix a mask (at least some positions must be True).
    mask_positions = torch.rand(B, T) < 0.5
    if not mask_positions.any():
        mask_positions[0, 0] = True
    # Non-negative random weights on masked positions, zero elsewhere
    # (torch.rand already returns float32, so no cast is needed).
    base_weights = torch.rand(B, T) * mask_positions.float()
    loss1 = mdlm_rb_loss(logits, targets, mask_positions, base_weights)
    loss2 = mdlm_rb_loss(logits, targets, mask_positions, base_weights * 2.0)
    assert loss2.item() == pytest.approx(loss1.item() * 2.0, rel=1e-5), (
        f"Expected 2x scaling: {loss1.item():.6f} * 2 ≠ {loss2.item():.6f}"
    )
# ---------------------------------------------------------------------------
# test_ce_matches_reference
# ---------------------------------------------------------------------------
def test_ce_matches_reference():
    """On a tiny deterministic case, compare against a hand-computed CE reference."""
    # Nothing below is random; the seed only keeps global RNG state deterministic.
    torch.manual_seed(99)
    B2, T2, V2 = 2, 4, 8
    # Real tokens only (no MASK_ID=0), so every target is a genuine class.
    targets = torch.tensor([[1, 2, 3, 4], [2, 3, 5, 6]])
    # All-zero logits → uniform distribution → CE = log(V2) at every position.
    logits = torch.zeros(B2, T2, V2)
    # Fixed mask: positions (0,0), (0,2), (1,1), (1,3) — two per row.
    mask_positions = torch.tensor([
        [True, False, True, False],
        [False, True, False, True],
    ])
    # Per-row alpha_t: row 0 → 0.5, row 1 → 0.25. Weight = 1/alpha_t on masked
    # positions (2.0 and 4.0, exact in float32) and 0 on unmasked ones.
    alpha = torch.tensor([0.5, 0.25])
    loss_weights = mask_positions.float() / alpha.unsqueeze(1)
    loss = mdlm_rb_loss(logits, targets, mask_positions, loss_weights)
    # Hand-computed reference: CE(uniform over V2=8) = ln(8)
    ce_ref = math.log(V2)
    # Row 0: 2 masked positions, each weight=2, CE=ln(8)
    #   weighted_sum = 2 * 2.0 * ln(8); per_sample = weighted_sum / 2 = 2.0 * ln(8)
    row0_loss = 2.0 * ce_ref
    # Row 1: 2 masked positions, each weight=4, CE=ln(8)
    #   weighted_sum = 2 * 4.0 * ln(8); per_sample = weighted_sum / 2 = 4.0 * ln(8)
    row1_loss = 4.0 * ce_ref
    expected = (row0_loss + row1_loss) / 2.0
    assert loss.item() == pytest.approx(expected, rel=1e-4), (
        f"Expected {expected:.6f}, got {loss.item():.6f}"
    )
# ---------------------------------------------------------------------------
# test_autograd_bf16
# ---------------------------------------------------------------------------
def test_autograd_bf16():
    """bf16 logits under autocast still give an fp32 loss and finite gradients."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    torch.manual_seed(42)
    n_rows, n_cols = 2, 16
    cuda = torch.device("cuda")
    tgt = _random_targets(b=n_rows, t=n_cols).to(cuda)
    logits_bf16 = torch.randn(
        n_rows, n_cols, V, device=cuda, dtype=torch.bfloat16, requires_grad=True
    )
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        _, mask, weights = mdlm_masked_forward_process(tgt, MASK_ID)
        loss = mdlm_rb_loss(logits_bf16, tgt, mask, weights)
    # The loss itself must come back in float32 ...
    assert loss.dtype == torch.float32, f"Expected float32 loss, got {loss.dtype}"
    # ... and backprop into the bf16 leaf must produce finite values.
    loss.backward()
    grad = logits_bf16.grad
    assert grad is not None, "No gradient on logits"
    assert torch.isfinite(grad).all(), "Inf/NaN in gradient"
# ---------------------------------------------------------------------------
# test_t_validation
# ---------------------------------------------------------------------------
def test_t_shape_error():
    """A t whose length does not match the batch dimension raises ValueError."""
    tgt = _random_targets()
    wrong_len_t = torch.rand(B + 1)  # one entry too many
    with pytest.raises(ValueError, match="shape"):
        mdlm_masked_forward_process(tgt, MASK_ID, t=wrong_len_t)
def test_t_range_error():
    """A t with entries outside [0, 1] raises ValueError."""
    tgt = _random_targets()
    out_of_range = 1.5 + torch.rand(B)  # every entry lies in [1.5, 2.5)
    with pytest.raises(ValueError, match=r"\[0, 1\]"):
        mdlm_masked_forward_process(tgt, MASK_ID, t=out_of_range)
# ---------------------------------------------------------------------------
# test_weight_clamping
# ---------------------------------------------------------------------------
def test_weight_clamping():
    """Loss weights stay finite and capped at _MAX_WEIGHT as t → 1 (alpha_t → 0)."""
    targets = _random_targets()
    # t is float32 here: 1.0 - 1e-9 would round to exactly 1.0 (float32 spacing
    # just below 1 is ~6e-8), so a "near 1" value must be at least that far away.
    # Check both the exact boundary t = 1 and a value genuinely close to, but below, 1.
    for t_val in (1.0, 1.0 - 1e-6):
        t = torch.full((B,), t_val)
        _, _, weights = mdlm_masked_forward_process(targets, MASK_ID, t=t)
        assert torch.isfinite(weights).all(), f"Non-finite weight at t={t_val}"
        assert (weights <= _MAX_WEIGHT + 1e-6).all(), (
            f"Weight exceeded _MAX_WEIGHT={_MAX_WEIGHT}; "
            f"max={weights.max().item()} (t={t_val})"
        )
# ---------------------------------------------------------------------------
# test_convenience_wrapper
# ---------------------------------------------------------------------------
def test_mdlm_loss_convenience():
    """The end-to-end mdlm_loss wrapper yields a finite float32 scalar."""
    torch.manual_seed(0)
    tgt = _random_targets()
    lg = _random_logits()
    out = mdlm_loss(lg, tgt, MASK_ID)
    assert out.ndim == 0, "Expected scalar loss"
    assert out.dtype == torch.float32
    assert torch.isfinite(out), f"Non-finite loss: {out.item()}"
def test_mdlm_loss_no_side_effects():
    """Calling mdlm_loss leaves its targets and logits inputs untouched."""
    tgt = _random_targets()
    lg = _random_logits()
    tgt_before, lg_before = tgt.clone(), lg.clone()
    mdlm_loss(lg, tgt, MASK_ID)
    assert torch.equal(tgt, tgt_before), "targets was mutated"
    assert torch.equal(lg, lg_before), "logits was mutated"
# ---------------------------------------------------------------------------
# test_alpha_schedule_unknown
# ---------------------------------------------------------------------------
def test_alpha_schedule_unknown():
    """An unsupported alpha_schedule name is rejected with ValueError."""
    tgt = _random_targets()
    bogus_schedule = "cosine"
    with pytest.raises(ValueError, match="Unknown alpha_schedule"):
        mdlm_masked_forward_process(
            tgt, MASK_ID, alpha_schedule=bogus_schedule  # type: ignore
        )