# Author: syedmohaiminulhoque
# Commit: Complete CDD implementation: Constrained Discrete Diffusion (arXiv:2503.09790v3)
# Revision: 2d0a056 (verified)
"""
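Unit tests for the CDD (Constrained Discrete Diffusion) implementation.
Tests core components without requiring a GPU or large model downloads:
1. Noise schedule correctness
2. Posterior computation (MDLM and UDLM)
3. Gumbel-Softmax relaxation
4. ALM projection (constraint violation does not worsen)
5. Simplex projection
6. Constraint function interfaces (counting constraint)
7. MDLM forward noising process
8. End-to-end sampling loop with a mock model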
"""
import sys
import torch
import torch.nn.functional as F
import numpy as np
# Add parent to path
sys.path.insert(0, '/app')
def test_noise_schedule():
    """Check that the log-linear schedule decays from ~1 at t=0 to ~0 at t=1."""
    from cdd.utils.noise_schedule import log_linear_schedule

    timesteps = torch.linspace(0, 1, 100)
    alphas = log_linear_schedule(timesteps)

    # Monotone non-increasing: every element is >= its successor.
    is_monotone = torch.all(alphas[:-1] >= alphas[1:])
    assert is_monotone, "Alpha should be monotonically decreasing"

    start, end = alphas[0], alphas[-1]
    # Essentially clean at the start of the process.
    assert start > 0.99, f"Alpha at t=0 should be ≈1.0, got {start:.4f}"
    # Essentially fully noised at the end.
    assert end < 0.001, f"Alpha at t=1 should be ≈0, got {end:.6f}"

    print("✓ Noise schedule test passed")
def test_simplex_projection():
    """Check project_to_simplex yields valid distributions and leaves simplex points fixed."""
    from cdd.utils.noise_schedule import project_to_simplex

    # Arbitrary real-valued input of shape (2, 5, 10), projected along the last axis.
    raw = torch.randn(2, 5, 10)
    on_simplex = project_to_simplex(raw, dim=-1)

    # Small negative tolerance absorbs floating-point error.
    assert torch.all(on_simplex >= -1e-6), "Projected values should be non-negative"

    totals = on_simplex.sum(dim=-1)
    sums_ok = torch.allclose(totals, torch.ones_like(totals), atol=1e-5)
    assert sums_ok, (
        f"Projected should sum to 1, got sums in [{totals.min():.4f}, {totals.max():.4f}]"
    )

    # Projecting a point that is already a distribution should be (nearly) a no-op.
    already_valid = F.softmax(torch.randn(3, 10), dim=-1)
    reprojected = project_to_simplex(already_valid, dim=-1)
    unchanged = torch.allclose(already_valid, reprojected, atol=1e-4)
    assert unchanged, "Simplex input should be unchanged"

    print("✓ Simplex projection test passed")
def test_mdlm_posterior():
    """Check the MDLM posterior: shape, normalisation, and carry-over of unmasked tokens."""
    from cdd.utils.noise_schedule import mdlm_posterior, log_linear_schedule

    batch, length, vocab = 2, 8, 20
    mask_id = vocab - 1

    # Draw non-mask tokens only, then mask positions 3..5 in every row.
    z_t = torch.randint(0, vocab - 1, (batch, length))
    z_t[:, 3:6] = mask_id

    x_theta = F.softmax(torch.randn(batch, length, vocab), dim=-1)
    alpha_t = log_linear_schedule(torch.tensor([0.5, 0.5]))
    alpha_s = log_linear_schedule(torch.tensor([0.4, 0.4]))

    posterior = mdlm_posterior(z_t, x_theta, alpha_t, alpha_s, mask_id)

    expected_shape = (batch, length, vocab)
    assert posterior.shape == expected_shape, f"Expected {expected_shape}, got {posterior.shape}"

    totals = posterior.sum(dim=-1)
    assert torch.allclose(totals, torch.ones_like(totals), atol=1e-4), \
        f"Posterior should sum to 1, got {totals}"
    assert torch.all(posterior >= -1e-6), "Posterior should be non-negative"

    # Positions that were never masked must keep their current token with ~certainty.
    unmasked_positions = (0, 1, 2, 6, 7)
    for row in range(batch):
        for pos in unmasked_positions:
            token = z_t[row, pos].item()
            prob = posterior[row, pos, token]
            assert prob > 0.99, f"Unmasked position should be one-hot, got {prob:.4f}"

    print("✓ MDLM posterior test passed")
def test_udlm_posterior():
    """Check that the UDLM posterior is a correctly shaped, valid probability distribution."""
    from cdd.utils.noise_schedule import udlm_posterior, log_linear_schedule

    batch, length, vocab = 2, 8, 20

    # udlm_posterior takes no mask id here, so z_t is drawn from the full vocabulary.
    z_t = torch.randint(0, vocab, (batch, length))
    x_theta = F.softmax(torch.randn(batch, length, vocab), dim=-1)

    t_now = torch.tensor([0.5, 0.5])
    t_prev = torch.tensor([0.4, 0.4])
    alpha_t = log_linear_schedule(t_now)
    alpha_s = log_linear_schedule(t_prev)

    posterior = udlm_posterior(z_t, x_theta, alpha_t, alpha_s)

    assert posterior.shape == (batch, length, vocab)

    totals = posterior.sum(dim=-1)
    normalised = torch.allclose(totals, torch.ones_like(totals), atol=1e-4)
    assert normalised, \
        f"UDLM posterior should sum to 1, got sums: min={totals.min():.4f}, max={totals.max():.4f}"

    assert torch.all(posterior >= -1e-6), "UDLM posterior should be non-negative"

    print("✓ UDLM posterior test passed")
def test_gumbel_softmax():
    """Test Gumbel-Softmax relaxation properties.

    With noise disabled the relaxation is deterministic, so this checks that:
    a very low temperature recovers the argmax of the input, a very high
    temperature flattens the distribution, and BOTH outputs are valid
    probability vectors over the last axis.
    """
    from cdd.samplers.cdd_sampler import gumbel_softmax

    # Sharply peaked logits of shape (2, 5, 10): index 3 dominates everywhere.
    logits = torch.zeros(2, 5, 10)
    logits[:, :, 3] = 10.0
    log_probs = F.log_softmax(logits, dim=-1)

    # Low temperature should approximate argmax.
    result_low = gumbel_softmax(log_probs, temperature=0.01, use_noise=False)
    assert torch.all(result_low.argmax(dim=-1) == 3), \
        "Low temperature Gumbel-Softmax should approximate argmax"

    # High temperature should be close to uniform (a uniform max prob over 10 classes is 0.1).
    result_high = gumbel_softmax(log_probs, temperature=100.0, use_noise=False)
    assert torch.all(result_high.max(dim=-1).values < 0.5), \
        "High temperature should give more uniform distribution"

    # Every output of the relaxation must be a valid probability distribution,
    # regardless of temperature.
    for label, out in (("low-temperature", result_low), ("high-temperature", result_high)):
        assert torch.all(out >= 0), f"{label} output should be non-negative"
        sums = out.sum(dim=-1)
        assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5), \
            f"{label} output should sum to 1, got sums in [{sums.min():.6f}, {sums.max():.6f}]"

    print("✓ Gumbel-Softmax test passed")
def test_alm_projection():
    """Test that ALM projection returns a valid distribution without worsening the constraint.

    The constraint asks token 5 at position 0 to have probability >= 0.9.
    The check is deliberately lenient: the final violation may exceed the
    initial one by at most ``tolerance``.
    """
    from cdd.samplers.cdd_sampler import alm_projection, ALMConfig

    B, L, V = 1, 4, 10
    tolerance = 0.1

    # Seeded local generator: the starting distribution (and hence the initial
    # violation) is reproducible across runs without touching the global RNG.
    gen = torch.Generator().manual_seed(0)
    x_theta = F.softmax(torch.randn(B, L, V, generator=gen), dim=-1)

    # g(x̃) = max(0, 0.9 - x̃[0, 0, 5]); zero once token 5 at position 0 has prob >= 0.9.
    def constraint_fn(x_tilde):
        prob_target = x_tilde[0, 0, 5]
        return F.relu(0.9 - prob_target)

    config = ALMConfig(
        lambda_init=0.0,
        mu_init=1.0,
        mu_max=100.0,
        outer_iter_max=50,
        inner_iter_max=5,
        eta=0.5,
        gumbel_temperature=0.1,
        use_gumbel_noise=False,  # Deterministic for testing
    )
    x_proj = alm_projection(x_theta, constraint_fn, config)

    # Shape preserved.
    assert x_proj.shape == x_theta.shape, "Shape should be preserved"

    # Valid (approximately normalised) probabilities.
    assert torch.all(x_proj >= 0), "Projected should be non-negative"
    sums = x_proj.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=0.1), \
        f"Projected should approximately sum to 1, got {sums}"

    # Constraint must not get meaningfully worse.
    initial_violation = constraint_fn(x_theta).item()
    final_violation = constraint_fn(x_proj).item()
    assert final_violation <= initial_violation + tolerance, (
        f"Projection should not increase violation by more than {tolerance}: "
        f"{initial_violation:.4f} → {final_violation:.4f}"
    )
    print(f"✓ ALM projection test passed (violation: {initial_violation:.4f} → {final_violation:.4f})")
def test_mdlm_forward_sample():
    """Check that MDLM forward noising masks more tokens near t=1 than near t=0."""
    from cdd.utils.noise_schedule import mdlm_forward_sample

    batch, length = 4, 16
    vocab = 50
    mask_id = vocab - 1
    # Clean tokens drawn from [0, vocab-1), so none equals the mask id up front.
    x_0 = torch.randint(0, vocab - 1, (batch, length))

    def count_masked(t_value):
        # Noise x_0 at a single shared timestep and count the masked entries.
        timesteps = torch.full((batch,), t_value)
        noised = mdlm_forward_sample(x_0, timesteps, mask_id, vocab)
        return (noised == mask_id).sum().item()

    n_masked_clean = count_masked(0.01)  # t ≈ 0: almost no masking
    n_masked_noisy = count_masked(0.99)  # t ≈ 1: almost everything masked

    assert n_masked_noisy > n_masked_clean, \
        f"More masking at t=1 ({n_masked_noisy}) vs t=0 ({n_masked_clean})"
    print(f"✓ MDLM forward sample test passed "
          f"(masked: t≈0: {n_masked_clean}, t≈1: {n_masked_noisy})")
def test_counting_constraint():
    """Check that CountingConstraint computes the ground-truth character count."""
    from cdd.constraints.instruction import CountingConstraint

    class MockTokenizer:
        """Minimal tokenizer: all-digit text maps to [100 + int(text)], other text to code points."""

        def encode(self, text, add_special_tokens=False):
            return [100 + int(text)] if text.isdigit() else [ord(ch) for ch in text]

    tok = MockTokenizer()

    # 's' occurs 4 times in "mississippi".
    c = CountingConstraint(tok, "mississippi", "s")
    assert c.correct_count == 4, \
        f"Expected 4 's' in 'mississippi', got {c.correct_count}"

    # 'z' never occurs in "hello".
    c = CountingConstraint(tok, "hello", "z")
    assert c.correct_count == 0, \
        f"Expected 0 'z' in 'hello', got {c.correct_count}"

    # 'l' occurs twice in "hello".
    assert CountingConstraint(tok, "hello", "l").correct_count == 2

    print("✓ Counting constraint test passed")
def test_full_sampling_mock():
    """Smoke-test CDDSampler end to end with a stub model and tokenizer.

    Covers three paths: unconstrained sampling, sampling with constraints
    enabled, and sampling conditioned on a fixed prefix.
    """
    from cdd.samplers.cdd_sampler import CDDSampler, ALMConfig

    seq_len, vocab = 8, 20
    hidden = 64

    def _namespace(name, **attrs):
        # Throwaway instance of a fresh class that carries ``attrs`` as class attributes.
        return type(name, (), attrs)()

    class MockModel(torch.nn.Module):
        """Stub exposing config/backbone attributes and emitting random logits."""

        def __init__(self):
            super().__init__()
            self.config = _namespace(
                'Config', vocab_size=vocab, model_length=seq_len, hidden_dim=hidden,
            )
            embed = _namespace('Embed', embedding=torch.randn(vocab, hidden))
            self.backbone = _namespace('Backbone', vocab_embed=embed)

        def forward(self, input_ids=None, timesteps=None, **kwargs):
            n_rows = input_ids.shape[0]
            return _namespace('Output', logits=torch.randn(n_rows, seq_len, vocab))

    class MockTokenizer:
        """Stub tokenizer with a fixed decoding and a fixed encoding."""

        mask_token_id = vocab - 1

        def decode(self, ids, skip_special_tokens=True):
            return "mock output text"

        def encode(self, text, add_special_tokens=False, max_length=None, truncation=False):
            return [1, 2, 3]

    def no_constraint(x):
        # Constant zero violation: every candidate is feasible.
        return torch.tensor(0.0)

    sampler = CDDSampler(
        model=MockModel(),
        tokenizer=MockTokenizer(),
        constraint_fn=no_constraint,
        alm_config=ALMConfig(outer_iter_max=2, inner_iter_max=2, eta=0.1),
        diffusion_type="mdlm",
        num_timesteps=5,  # keep the loop tiny for testing
        seq_length=seq_len,
        device="cpu",
    )

    # 1) Unconstrained sampling.
    unconstrained = sampler.sample(batch_size=1, apply_constraints=False)
    for key in ("sequences", "text"):
        assert key in unconstrained
    assert unconstrained["sequences"].shape == (1, seq_len)
    assert len(unconstrained["text"]) == 1

    # 2) Sampling with the constraint projection switched on.
    constrained = sampler.sample(batch_size=1, apply_constraints=True)
    assert constrained["sequences"].shape == (1, seq_len)

    # 3) Prefix-conditioned sampling must keep the prefix tokens verbatim.
    prefix = torch.tensor([[1, 2, 3]], dtype=torch.long)
    prefixed = sampler.sample(batch_size=1, prefix_ids=prefix, apply_constraints=False)
    n_prefix = prefix.shape[1]
    assert torch.all(prefixed["sequences"][0, :n_prefix] == prefix[0])

    print("✓ Full sampling loop test passed")
def run_all_tests():
    """Execute every unit test in order and print a pass/fail summary.

    Each test runs in isolation. Any ``Exception`` it raises is reported
    with its traceback and counted as a failure, and the remaining tests
    still run.

    Returns:
        True if every test passed, False otherwise.
    """
    import traceback

    banner = "=" * 60
    print(banner)
    print("CDD Implementation Unit Tests")
    print(banner)
    print()

    suite = (
        test_noise_schedule,
        test_simplex_projection,
        test_mdlm_posterior,
        test_udlm_posterior,
        test_gumbel_softmax,
        test_alm_projection,
        test_mdlm_forward_sample,
        test_counting_constraint,
        test_full_sampling_mock,
    )

    failures = 0
    for test_fn in suite:
        try:
            test_fn()
        except Exception as exc:
            print(f"✗ {test_fn.__name__} FAILED: {exc}")
            traceback.print_exc()
            failures += 1
    successes = len(suite) - failures

    print()
    print(banner)
    print(f"Results: {successes} passed, {failures} failed out of {len(suite)}")
    print(banner)
    return failures == 0
if __name__ == "__main__":
    # Report the overall result as the process exit status (0 = all passed, 1 = any failure).
    exit_code = 0 if run_all_tests() else 1
    sys.exit(exit_code)