Spaces:
Runtime error
Runtime error
File size: 12,445 Bytes
e317e25 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 | """Tests for hydra/diffusion_loss.py - MDLM Rao-Blackwellized loss.
Paper: Sahoo et al., "Simple and Effective Masked Diffusion Language Models"
arXiv:2406.07524, NeurIPS 2024.
"""
from __future__ import annotations
import importlib.util
import math
import sys
from pathlib import Path
import pytest
import torch
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Import diffusion_loss directly from the file to avoid triggering
# hydra/__init__.py, which eagerly imports mamba_ssm (not available in the
# test environment without a GPU build). diffusion_loss.py has zero heavy deps.
# ---------------------------------------------------------------------------
# Load hydra/diffusion_loss.py by file path and re-export the names under test.
_MODULE_PATH = Path(__file__).parent.parent / "hydra" / "diffusion_loss.py"
_spec = importlib.util.spec_from_file_location("hydra.diffusion_loss", _MODULE_PATH)
if _spec is None or _spec.loader is None:
    # spec_from_file_location may return None (and a spec may carry no loader);
    # fail with a clear ImportError instead of an AttributeError on None below.
    raise ImportError(f"Cannot create an import spec for {_MODULE_PATH}")
_diffusion_loss_mod = importlib.util.module_from_spec(_spec)
# Register the module before executing it, as in the importlib recipe for
# importing a source file directly.
sys.modules["hydra.diffusion_loss"] = _diffusion_loss_mod
try:
    _spec.loader.exec_module(_diffusion_loss_mod)
except BaseException:
    # Mirror the regular import machinery: a module whose execution failed
    # must not stay registered, or later lookups would get a broken module.
    sys.modules.pop("hydra.diffusion_loss", None)
    raise

# Names exercised by the tests in this file.
_MAX_WEIGHT = _diffusion_loss_mod._MAX_WEIGHT
_MIN_ALPHA = _diffusion_loss_mod._MIN_ALPHA
mdlm_masked_forward_process = _diffusion_loss_mod.mdlm_masked_forward_process
mdlm_rb_loss = _diffusion_loss_mod.mdlm_rb_loss
mdlm_loss = _diffusion_loss_mod.mdlm_loss
# ---------------------------------------------------------------------------
# Fixtures / helpers
# ---------------------------------------------------------------------------
# Default dimensions used by the helpers and tests in this file:
# targets are (B, T) id tensors, logits are (B, T, V).
B = 4
T = 32
V = 512
# Id passed to the loss functions as the mask token.
MASK_ID = 0
def _random_targets(b=B, t=T, v=V) -> torch.Tensor:
    """Return a (b, t) tensor of random token ids drawn from [1, v).

    Id 0 is never produced, so MASK_ID=0 stays unambiguously special.
    """
    shape = (b, t)
    return torch.randint(low=1, high=v, size=shape)
def _random_logits(b=B, t=T, v=V) -> torch.Tensor:
    """Return standard-normal random logits of shape (b, t, v)."""
    shape = (b, t, v)
    return torch.randn(shape)
# ---------------------------------------------------------------------------
# test_forward_process_shape
# ---------------------------------------------------------------------------
def test_forward_process_shape():
    """The forward process returns three (B, T) tensors: int64, bool, float32."""
    x_t, mask, weights = mdlm_masked_forward_process(_random_targets(), MASK_ID)
    expected_shape = (B, T)

    # Corrupted sequence.
    assert x_t.shape == expected_shape, f"x_t shape: {x_t.shape}"
    assert x_t.dtype == torch.int64, f"x_t dtype: {x_t.dtype}"

    # Boolean mask of corrupted positions.
    assert mask.shape == expected_shape, f"mask shape: {mask.shape}"
    assert mask.dtype == torch.bool, f"mask dtype: {mask.dtype}"

    # Per-position loss weights.
    assert weights.shape == expected_shape, f"weights shape: {weights.shape}"
    assert weights.dtype == torch.float32, f"weights dtype: {weights.dtype}"
def test_forward_process_values_consistent():
    """Masked positions hold MASK_ID, the rest keep their token; weights follow the mask."""
    targets = _random_targets()
    x_t, mask, weights = mdlm_masked_forward_process(targets, MASK_ID)
    kept = ~mask

    # Masked -> mask token id
    assert torch.all(x_t[mask] == MASK_ID), "Masked positions should equal MASK_ID"
    # Unmasked -> original token
    assert torch.all(x_t[kept] == targets[kept]), "Unmasked positions should equal original"

    # Weights are zero off the mask and strictly positive on it.
    assert torch.all(weights[kept] == 0.0), "Weights on unmasked positions should be 0"
    assert torch.all(weights[mask] > 0.0), "Weights on masked positions should be > 0"
# ---------------------------------------------------------------------------
# test_mask_fraction
# ---------------------------------------------------------------------------
def test_mask_fraction():
    """Averaged over many draws, the masked fraction is close to mean(t) = 0.5."""
    torch.manual_seed(42)
    n_trials = 2000
    masked_count = 0
    token_count = 0
    for _ in range(n_trials):
        _, mask, _ = mdlm_masked_forward_process(_random_targets(b=4, t=16), MASK_ID)
        masked_count += int(mask.sum())
        token_count += mask.numel()
    empirical_frac = masked_count / token_count

    # Expected: E[mask_fraction] = E[1 - alpha_t] = E[t] = 0.5.
    # Treating all n_trials * 4 * 16 tokens as independent gives
    # std ~ 0.5 / sqrt(128000) ~ 0.0014; the assertion uses a looser 0.01 bound.
    # NOTE(review): if t is drawn once per sequence, tokens within a row are
    # correlated and the true std is larger - confirm against the implementation.
    # NOTE(review): the "β" in the message below looks like a mis-encoded
    # character; it is a runtime string and is left unchanged here.
    assert abs(empirical_frac - 0.5) < 0.01, (
        f"Expected mask fraction β 0.5, got {empirical_frac:.4f}"
    )
def test_mask_fraction_with_fixed_t():
    """With t fixed at 0.3 the masked fraction is ~0.3 (1 - alpha_t = 1 - 0.7)."""
    torch.manual_seed(7)
    n_trials = 1000
    t_val = 0.3
    masked_count = 0
    token_count = 0
    for _ in range(n_trials):
        targets = _random_targets(b=4, t=32)
        # One t value per batch row, all equal to t_val.
        fixed_t = torch.full((4,), t_val)
        _, mask, _ = mdlm_masked_forward_process(targets, MASK_ID, t=fixed_t)
        masked_count += int(mask.sum())
        token_count += mask.numel()
    empirical_frac = masked_count / token_count

    # NOTE(review): the "β" in the message below looks like a mis-encoded
    # character; it is a runtime string and is left unchanged here.
    assert abs(empirical_frac - t_val) < 0.02, (
        f"Expected mask fraction β {t_val}, got {empirical_frac:.4f}"
    )
# ---------------------------------------------------------------------------
# test_unmasked_loss_zero
# ---------------------------------------------------------------------------
def test_unmasked_loss_zero():
    """With an all-False mask and all-zero weights, mdlm_rb_loss is 0 (within 1e-6)."""
    targets = _random_targets()
    logits = _random_logits()

    # Nothing is masked and every weight is zero.
    nothing_masked = torch.zeros((B, T), dtype=torch.bool)
    zero_weights = torch.zeros((B, T))

    loss = mdlm_rb_loss(logits, targets, nothing_masked, zero_weights)
    assert loss.item() == pytest.approx(0.0, abs=1e-6), (
        f"Expected 0.0 when nothing is masked, got {loss.item()}"
    )
# ---------------------------------------------------------------------------
# test_loss_scales_with_weight
# ---------------------------------------------------------------------------
def test_loss_scales_with_weight():
    """The loss is linear in loss_weights: doubling the weights doubles the loss."""
    torch.manual_seed(1234)
    targets = _random_targets()
    logits = _random_logits()

    # Random mask with roughly half the positions set; guarantee at least one.
    mask_positions = torch.rand(B, T) < 0.5
    if not mask_positions.any():
        mask_positions[0, 0] = True

    # Random weights in [0, 1), zeroed outside the mask.
    base_weights = torch.rand(B, T) * mask_positions.to(torch.float32)
    doubled_weights = base_weights * 2.0

    loss1 = mdlm_rb_loss(logits, targets, mask_positions, base_weights)
    loss2 = mdlm_rb_loss(logits, targets, mask_positions, doubled_weights)

    # NOTE(review): the "β" in the message below looks like a mis-encoded
    # character; it is a runtime string and is left unchanged here.
    assert loss2.item() == pytest.approx(loss1.item() * 2.0, rel=1e-5), (
        f"Expected 2x scaling: {loss1.item():.6f} * 2 β {loss2.item():.6f}"
    )
# ---------------------------------------------------------------------------
# test_ce_matches_reference
# ---------------------------------------------------------------------------
def test_ce_matches_reference():
    """Compare mdlm_rb_loss with a hand-computed value on a tiny fixed case.

    All-zero logits give a uniform distribution over V2 = 8 classes, so the
    cross-entropy at every position is ln(8). Each row has two masked
    positions weighted by 1/alpha (2 for row 0, 4 for row 1). The expected
    value assumes the loss divides each row's weighted sum by its number of
    masked positions and then averages over rows: (2 + 4) / 2 * ln(8).
    """
    # Nothing below is random; the seed is kept so this test leaves the global
    # RNG in the same state as before for whatever runs after it.
    torch.manual_seed(99)
    B2, T2, V2 = 2, 4, 8

    # Targets avoid MASK_ID (0) so every entry is a "real" token.
    targets = torch.tensor([[1, 2, 3, 4], [2, 3, 5, 6]])
    # Fixed logits (all zeros -> uniform distribution -> CE = log(V))
    logits = torch.zeros(B2, T2, V2)
    # Fixed mask: mask positions (0,0), (0,2), (1,1), (1,3)
    mask_positions = torch.tensor([
        [True, False, True, False],
        [False, True, False, True],
    ])

    # Fixed alpha_t: row 0 -> alpha=0.5, row 1 -> alpha=0.25
    # Loss weights: 1/alpha on masked positions (2 and 4), 0 elsewhere.
    alpha = torch.tensor([0.5, 0.25])
    # Broadcast (B2, T2) / (B2, 1): 1/alpha where masked, 0/alpha = 0 elsewhere.
    loss_weights = mask_positions.float() / alpha.unsqueeze(1)

    loss = mdlm_rb_loss(logits, targets, mask_positions, loss_weights)

    # Manual reference: CE(uniform over V2=8) = ln(8)
    ce_ref = math.log(V2)
    # Row 0: 2 masked positions, each weight=2 -> (2 * 2.0 * ln(8)) / 2 = 2.0 * ln(8)
    row0_loss = 2.0 * ce_ref
    # Row 1: 2 masked positions, each weight=4 -> (2 * 4.0 * ln(8)) / 2 = 4.0 * ln(8)
    row1_loss = 4.0 * ce_ref
    expected = (row0_loss + row1_loss) / 2.0

    assert loss.item() == pytest.approx(expected, rel=1e-4), (
        f"Expected {expected:.6f}, got {loss.item():.6f}"
    )
# ---------------------------------------------------------------------------
# test_autograd_bf16
# ---------------------------------------------------------------------------
def test_autograd_bf16():
    """With bf16 logits under autocast, the loss is float32 and grads are finite."""
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")

    torch.manual_seed(42)
    B3, T3, V3 = 2, 16, V
    cuda = torch.device("cuda")

    targets = _random_targets(b=B3, t=T3).to(cuda)
    # Leaf tensor in bfloat16 so the gradient can be inspected afterwards.
    logits = torch.randn(
        (B3, T3, V3), device=cuda, dtype=torch.bfloat16, requires_grad=True
    )

    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        _, mask, weights = mdlm_masked_forward_process(targets, MASK_ID)
        loss = mdlm_rb_loss(logits, targets, mask, weights)

    # Loss must be float32
    assert loss.dtype == torch.float32, f"Expected float32 loss, got {loss.dtype}"

    # Backward must succeed and produce finite grads
    loss.backward()
    grad = logits.grad
    assert grad is not None, "No gradient on logits"
    assert torch.isfinite(grad).all(), "Inf/NaN in gradient"
# ---------------------------------------------------------------------------
# test_t_validation
# ---------------------------------------------------------------------------
def test_t_shape_error():
    """A t tensor whose length differs from the batch size raises ValueError."""
    targets = _random_targets()
    # B + 1 entries for a batch of B rows.
    wrong_len_t = torch.rand(B + 1)
    with pytest.raises(ValueError, match="shape"):
        mdlm_masked_forward_process(targets, MASK_ID, t=wrong_len_t)
def test_t_range_error():
    """A t tensor with values outside [0, 1] raises ValueError."""
    targets = _random_targets()
    # rand() is in [0, 1), so every entry lands in [1.5, 2.5).
    out_of_range_t = 1.5 + torch.rand(B)
    with pytest.raises(ValueError, match=r"\[0, 1\]"):
        mdlm_masked_forward_process(targets, MASK_ID, t=out_of_range_t)
# ---------------------------------------------------------------------------
# test_weight_clamping
# ---------------------------------------------------------------------------
def test_weight_clamping():
    """Loss weights never exceed _MAX_WEIGHT, even as t -> 1 (alpha_t -> 0)."""
    targets = _random_targets()

    # t very close to 1 -> alpha_t very close to 0.
    # NOTE(review): under the default float32 dtype, 1.0 - 1e-9 is not
    # representable and rounds to exactly 1.0, so this exercises t == 1.
    t = torch.full((B,), 1.0 - 1e-9)
    _, _, weights = mdlm_masked_forward_process(targets, MASK_ID, t=t)

    # Small slack for floating-point comparison.
    upper_bound = _MAX_WEIGHT + 1e-6
    assert torch.all(weights <= upper_bound), (
        f"Weight exceeded _MAX_WEIGHT={_MAX_WEIGHT}; max={weights.max().item()}"
    )
# ---------------------------------------------------------------------------
# test_convenience_wrapper
# ---------------------------------------------------------------------------
def test_mdlm_loss_convenience():
    """mdlm_loss, called end to end, yields a finite 0-dim float32 tensor."""
    torch.manual_seed(0)
    targets = _random_targets()
    logits = _random_logits()

    loss = mdlm_loss(logits, targets, MASK_ID)

    assert loss.dim() == 0, "Expected scalar loss"
    assert loss.dtype == torch.float32
    assert torch.isfinite(loss), f"Non-finite loss: {loss.item()}"
def test_mdlm_loss_no_side_effects():
    """Calling mdlm_loss leaves both input tensors unchanged."""
    targets = _random_targets()
    logits = _random_logits()

    # Snapshots taken before the call, compared against the live tensors after.
    targets_before = targets.clone()
    logits_before = logits.clone()

    mdlm_loss(logits, targets, MASK_ID)

    assert torch.equal(targets, targets_before), "targets was mutated"
    assert torch.equal(logits, logits_before), "logits was mutated"
# ---------------------------------------------------------------------------
# test_alpha_schedule_unknown
# ---------------------------------------------------------------------------
def test_alpha_schedule_unknown():
    """An unrecognised alpha_schedule name is rejected with ValueError."""
    targets = _random_targets()
    expect_rejection = pytest.raises(ValueError, match="Unknown alpha_schedule")
    with expect_rejection:
        mdlm_masked_forward_process(targets, MASK_ID, alpha_schedule="cosine")  # type: ignore
|