""" Unit tests for CDD implementation. Tests core components without requiring GPU or large model downloads: 1. Noise schedule correctness 2. Posterior computation (MDLM and UDLM) 3. Gumbel-Softmax relaxation 4. ALM projection convergence 5. Simplex projection 6. Constraint function interfaces """ import sys import torch import torch.nn.functional as F import numpy as np # Add parent to path sys.path.insert(0, '/app') def test_noise_schedule(): """Test log-linear noise schedule properties.""" from cdd.utils.noise_schedule import log_linear_schedule t = torch.linspace(0, 1, 100) alpha = log_linear_schedule(t) # α should be decreasing assert torch.all(alpha[:-1] >= alpha[1:]), "Alpha should be monotonically decreasing" # α(0) ≈ 1.0 assert alpha[0] > 0.99, f"Alpha at t=0 should be ≈1.0, got {alpha[0]:.4f}" # α(1) ≈ eps assert alpha[-1] < 0.001, f"Alpha at t=1 should be ≈0, got {alpha[-1]:.6f}" print("✓ Noise schedule test passed") def test_simplex_projection(): """Test simplex projection correctness.""" from cdd.utils.noise_schedule import project_to_simplex # Random input x = torch.randn(2, 5, 10) projected = project_to_simplex(x, dim=-1) # Should be non-negative assert torch.all(projected >= -1e-6), "Projected values should be non-negative" # Should sum to 1 sums = projected.sum(dim=-1) assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5), \ f"Projected should sum to 1, got sums in [{sums.min():.4f}, {sums.max():.4f}]" # Already on simplex should not change much y = F.softmax(torch.randn(3, 10), dim=-1) y_proj = project_to_simplex(y, dim=-1) assert torch.allclose(y, y_proj, atol=1e-4), "Simplex input should be unchanged" print("✓ Simplex projection test passed") def test_mdlm_posterior(): """Test MDLM posterior computation.""" from cdd.utils.noise_schedule import mdlm_posterior, log_linear_schedule B, L, V = 2, 8, 20 mask_id = V - 1 # Create mixed masked/unmasked sequence z_t = torch.randint(0, V-1, (B, L)) z_t[:, 3:6] = mask_id # Mask positions 3-5 x_theta = F.softmax(torch.randn(B, L, V), dim=-1) t = torch.tensor([0.5, 0.5]) s = torch.tensor([0.4, 0.4]) alpha_t = log_linear_schedule(t) alpha_s = log_linear_schedule(s) posterior = mdlm_posterior(z_t, x_theta, alpha_t, alpha_s, mask_id) # Shape check assert posterior.shape == (B, L, V), f"Expected {(B,L,V)}, got {posterior.shape}" # Should be valid probability distributions sums = posterior.sum(dim=-1) assert torch.allclose(sums, torch.ones_like(sums), atol=1e-4), \ f"Posterior should sum to 1, got {sums}" # Non-negative assert torch.all(posterior >= -1e-6), "Posterior should be non-negative" # Unmasked positions should be one-hot at current token for b in range(B): for pos in [0, 1, 2, 6, 7]: # Unmasked positions token = z_t[b, pos].item() assert posterior[b, pos, token] > 0.99, \ f"Unmasked position should be one-hot, got {posterior[b, pos, token]:.4f}" print("✓ MDLM posterior test passed") def test_udlm_posterior(): """Test UDLM posterior computation.""" from cdd.utils.noise_schedule import udlm_posterior, log_linear_schedule B, L, V = 2, 8, 20 z_t = torch.randint(0, V, (B, L)) x_theta = F.softmax(torch.randn(B, L, V), dim=-1) t = torch.tensor([0.5, 0.5]) s = torch.tensor([0.4, 0.4]) alpha_t = log_linear_schedule(t) alpha_s = log_linear_schedule(s) posterior = udlm_posterior(z_t, x_theta, alpha_t, alpha_s) # Shape check assert posterior.shape == (B, L, V) # Valid probability distribution sums = posterior.sum(dim=-1) assert torch.allclose(sums, torch.ones_like(sums), atol=1e-4), \ f"UDLM posterior should sum to 1, got sums: min={sums.min():.4f}, max={sums.max():.4f}" # Non-negative assert torch.all(posterior >= -1e-6), "UDLM posterior should be non-negative" print("✓ UDLM posterior test passed") def test_gumbel_softmax(): """Test Gumbel-Softmax relaxation properties.""" from cdd.samplers.cdd_sampler import gumbel_softmax # Create peaked distribution logits = torch.zeros(2, 5, 10) logits[:, :, 3] = 10.0 # Peak at index 3 log_probs = F.log_softmax(logits, dim=-1) # Low temperature should approximate argmax result = gumbel_softmax(log_probs, temperature=0.01, use_noise=False) argmax_vals = result.argmax(dim=-1) assert torch.all(argmax_vals == 3), \ "Low temperature Gumbel-Softmax should approximate argmax" # High temperature should be more uniform result_high = gumbel_softmax(log_probs, temperature=100.0, use_noise=False) max_probs = result_high.max(dim=-1).values assert torch.all(max_probs < 0.5), \ "High temperature should give more uniform distribution" # Output should be valid probabilities assert torch.all(result >= 0), "Output should be non-negative" sums = result.sum(dim=-1) assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5), \ "Output should sum to 1" print("✓ Gumbel-Softmax test passed") def test_alm_projection(): """Test ALM projection converges to satisfy a simple constraint.""" from cdd.samplers.cdd_sampler import alm_projection, ALMConfig B, L, V = 1, 4, 10 # Create input distribution x_theta = F.softmax(torch.randn(B, L, V), dim=-1) # Simple constraint: argmax of position 0 should be token 5 # g(x̃) = 1 - x̃[0, 0, 5] (violation if token 5 doesn't have high prob) def constraint_fn(x_tilde): prob_target = x_tilde[0, 0, 5] return F.relu(0.9 - prob_target) # Want prob > 0.9 config = ALMConfig( lambda_init=0.0, mu_init=1.0, mu_max=100.0, outer_iter_max=50, inner_iter_max=5, eta=0.5, gumbel_temperature=0.1, use_gumbel_noise=False, # Deterministic for testing ) x_proj = alm_projection(x_theta, constraint_fn, config) # Check shape preserved assert x_proj.shape == x_theta.shape, "Shape should be preserved" # Check valid probabilities assert torch.all(x_proj >= 0), "Projected should be non-negative" sums = x_proj.sum(dim=-1) assert torch.allclose(sums, torch.ones_like(sums), atol=0.1), \ f"Projected should approximately sum to 1, got {sums}" # Check constraint is better satisfied initial_violation = constraint_fn(x_theta).item() final_violation = constraint_fn(x_proj).item() assert final_violation <= initial_violation + 0.1, \ f"Projection should reduce violation: {initial_violation:.4f} → {final_violation:.4f}" print(f"✓ ALM projection test passed (violation: {initial_violation:.4f} → {final_violation:.4f})") def test_mdlm_forward_sample(): """Test MDLM forward noising process.""" from cdd.utils.noise_schedule import mdlm_forward_sample B, L = 4, 16 V = 50 mask_id = V - 1 x_0 = torch.randint(0, V-1, (B, L)) # Clean tokens (no mask tokens) # At t=0 (clean): almost no masking t_clean = torch.full((B,), 0.01) z_clean = mdlm_forward_sample(x_0, t_clean, mask_id, V) n_masked_clean = (z_clean == mask_id).sum().item() # At t=1 (noisy): almost everything masked t_noisy = torch.full((B,), 0.99) z_noisy = mdlm_forward_sample(x_0, t_noisy, mask_id, V) n_masked_noisy = (z_noisy == mask_id).sum().item() assert n_masked_noisy > n_masked_clean, \ f"More masking at t=1 ({n_masked_noisy}) vs t=0 ({n_masked_clean})" print(f"✓ MDLM forward sample test passed " f"(masked: t≈0: {n_masked_clean}, t≈1: {n_masked_noisy})") def test_counting_constraint(): """Test counting constraint logic.""" from cdd.constraints.instruction import CountingConstraint # Mock tokenizer class MockTokenizer: def encode(self, text, add_special_tokens=False): # Map single digits to token ids 100+digit if text.isdigit(): return [100 + int(text)] return [ord(c) for c in text] tokenizer = MockTokenizer() # "How many 's' in 'mississippi'?" → answer is 4 constraint = CountingConstraint(tokenizer, "mississippi", "s") assert constraint.correct_count == 4, \ f"Expected 4 's' in 'mississippi', got {constraint.correct_count}" # "How many 'z' in 'hello'?" → answer is 0 constraint2 = CountingConstraint(tokenizer, "hello", "z") assert constraint2.correct_count == 0, \ f"Expected 0 'z' in 'hello', got {constraint2.correct_count}" # "How many 'l' in 'hello'?" → answer is 2 constraint3 = CountingConstraint(tokenizer, "hello", "l") assert constraint3.correct_count == 2 print("✓ Counting constraint test passed") def test_full_sampling_mock(): """Test the full sampling loop with a mock model.""" from cdd.samplers.cdd_sampler import CDDSampler, ALMConfig B, L, V = 1, 8, 20 # Mock model that returns random logits class MockModel(torch.nn.Module): def __init__(self): super().__init__() self.config = type('Config', (), { 'vocab_size': V, 'model_length': L, 'hidden_dim': 64, })() self.backbone = type('Backbone', (), { 'vocab_embed': type('Embed', (), { 'embedding': torch.randn(V, 64) })() })() def forward(self, input_ids=None, timesteps=None, **kwargs): B = input_ids.shape[0] logits = torch.randn(B, L, V) return type('Output', (), {'logits': logits})() class MockTokenizer: mask_token_id = V - 1 def decode(self, ids, skip_special_tokens=True): return "mock output text" def encode(self, text, add_special_tokens=False, max_length=None, truncation=False): return [1, 2, 3] model = MockModel() tokenizer = MockTokenizer() # No constraint def no_constraint(x): return torch.tensor(0.0) config = ALMConfig( outer_iter_max=2, inner_iter_max=2, eta=0.1, ) sampler = CDDSampler( model=model, tokenizer=tokenizer, constraint_fn=no_constraint, alm_config=config, diffusion_type="mdlm", num_timesteps=5, # Very few steps for testing seq_length=L, device="cpu", ) result = sampler.sample(batch_size=1, apply_constraints=False) assert "sequences" in result assert "text" in result assert result["sequences"].shape == (1, L) assert len(result["text"]) == 1 # Test with constraints result_constrained = sampler.sample(batch_size=1, apply_constraints=True) assert result_constrained["sequences"].shape == (1, L) # Test with prefix prefix = torch.tensor([[1, 2, 3]], dtype=torch.long) result_prefix = sampler.sample( batch_size=1, prefix_ids=prefix, apply_constraints=False, ) assert torch.all(result_prefix["sequences"][0, :3] == prefix[0]) print("✓ Full sampling loop test passed") def run_all_tests(): """Run all unit tests.""" print("=" * 60) print("CDD Implementation Unit Tests") print("=" * 60) print() tests = [ test_noise_schedule, test_simplex_projection, test_mdlm_posterior, test_udlm_posterior, test_gumbel_softmax, test_alm_projection, test_mdlm_forward_sample, test_counting_constraint, test_full_sampling_mock, ] passed = 0 failed = 0 for test in tests: try: test() passed += 1 except Exception as e: print(f"✗ {test.__name__} FAILED: {e}") import traceback traceback.print_exc() failed += 1 print() print("=" * 60) print(f"Results: {passed} passed, {failed} failed out of {len(tests)}") print("=" * 60) return failed == 0 if __name__ == "__main__": success = run_all_tests() sys.exit(0 if success else 1)