| """ |
| Unit tests for CDD implementation. |
| |
| Tests core components without requiring GPU or large model downloads: |
| 1. Noise schedule correctness |
| 2. Posterior computation (MDLM and UDLM) |
| 3. Gumbel-Softmax relaxation |
| 4. ALM projection convergence |
| 5. Simplex projection |
| 6. Constraint function interfaces |
| """ |
|
|
| import sys |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
|
|
| |
| sys.path.insert(0, '/app') |
|
|
|
|
def test_noise_schedule():
    """Check basic properties of the log-linear noise schedule.

    Evaluates ``log_linear_schedule`` at 100 evenly spaced times in [0, 1]
    and asserts that alpha never increases, is above 0.99 at the first
    time and below 0.001 at the last.
    """
    from cdd.utils.noise_schedule import log_linear_schedule

    timesteps = torch.linspace(0, 1, 100)
    alpha = log_linear_schedule(timesteps)

    # Compare each value with its successor: none may be smaller.
    current, following = alpha[:-1], alpha[1:]
    assert torch.all(current >= following), "Alpha should be monotonically decreasing"

    # Endpoint checks at t=0 and t=1.
    first, last = alpha[0], alpha[-1]
    assert first > 0.99, f"Alpha at t=0 should be ≈1.0, got {first:.4f}"
    assert last < 0.001, f"Alpha at t=1 should be ≈0, got {last:.6f}"

    print("✓ Noise schedule test passed")
|
|
|
|
def test_simplex_projection():
    """Check that ``project_to_simplex`` yields valid simplex points.

    Two cases are exercised along the last axis:
    * random real-valued input must come back non-negative (within 1e-6)
      and summing to 1 (within 1e-5);
    * input that is already a softmax output must come back unchanged
      (within 1e-4).
    """
    from cdd.utils.noise_schedule import project_to_simplex

    # Case 1: arbitrary input.
    raw = torch.randn(2, 5, 10)
    projected = project_to_simplex(raw, dim=-1)

    assert torch.all(projected >= -1e-6), "Projected values should be non-negative"

    sums = projected.sum(dim=-1)
    target = torch.ones_like(sums)
    assert torch.allclose(sums, target, atol=1e-5), (
        f"Projected should sum to 1, got sums in [{sums.min():.4f}, {sums.max():.4f}]"
    )

    # Case 2: input that already lies on the simplex.
    on_simplex = F.softmax(torch.randn(3, 10), dim=-1)
    reprojected = project_to_simplex(on_simplex, dim=-1)
    assert torch.allclose(on_simplex, reprojected, atol=1e-4), "Simplex input should be unchanged"

    print("✓ Simplex projection test passed")
|
|
|
|
def test_mdlm_posterior():
    """Sanity-check the output of ``mdlm_posterior``.

    Builds a (2, 8) batch of random non-mask tokens, overwrites positions
    3..5 with the mask id (the last vocabulary index), and asserts that the
    result:
    * has shape (B, L, V);
    * sums to 1 over the last axis (within 1e-4) and is non-negative
      (within 1e-6);
    * puts more than 0.99 of the mass on the existing token at every
      position that was not masked.
    """
    from cdd.utils.noise_schedule import mdlm_posterior, log_linear_schedule

    batch, length, vocab = 2, 8, 20
    mask_id = vocab - 1
    masked_span = range(3, 6)

    # Random tokens drawn from the non-mask part of the vocabulary.
    z_t = torch.randint(0, vocab - 1, (batch, length))
    z_t[:, masked_span.start:masked_span.stop] = mask_id

    x_theta = F.softmax(torch.randn(batch, length, vocab), dim=-1)

    t = torch.tensor([0.5, 0.5])
    s = torch.tensor([0.4, 0.4])
    alpha_t = log_linear_schedule(t)
    alpha_s = log_linear_schedule(s)

    posterior = mdlm_posterior(z_t, x_theta, alpha_t, alpha_s, mask_id)

    expected_shape = (batch, length, vocab)
    assert posterior.shape == expected_shape, (
        f"Expected {expected_shape}, got {posterior.shape}"
    )

    sums = posterior.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-4), (
        f"Posterior should sum to 1, got {sums}"
    )

    assert torch.all(posterior >= -1e-6), "Posterior should be non-negative"

    # Every position outside the masked span: 0, 1, 2, 6, 7.
    unmasked = [pos for pos in range(length) if pos not in masked_span]
    for b in range(batch):
        for pos in unmasked:
            token = z_t[b, pos].item()
            kept_mass = posterior[b, pos, token]
            assert kept_mass > 0.99, (
                f"Unmasked position should be one-hot, got {kept_mass:.4f}"
            )

    print("✓ MDLM posterior test passed")
|
|
|
|
def test_udlm_posterior():
    """Sanity-check the output of ``udlm_posterior`` on random inputs.

    Draws random tokens over the whole vocabulary plus a random softmax
    prediction, then asserts the result has shape (B, L, V), sums to 1 over
    the last axis (within 1e-4) and is non-negative (within 1e-6).
    """
    from cdd.utils.noise_schedule import udlm_posterior, log_linear_schedule

    batch, length, vocab = 2, 8, 20

    z_t = torch.randint(0, vocab, (batch, length))
    x_theta = F.softmax(torch.randn(batch, length, vocab), dim=-1)

    t = torch.tensor([0.5, 0.5])
    s = torch.tensor([0.4, 0.4])
    alpha_t = log_linear_schedule(t)
    alpha_s = log_linear_schedule(s)

    posterior = udlm_posterior(z_t, x_theta, alpha_t, alpha_s)

    assert posterior.shape == (batch, length, vocab)

    sums = posterior.sum(dim=-1)
    ones = torch.ones_like(sums)
    assert torch.allclose(sums, ones, atol=1e-4), (
        f"UDLM posterior should sum to 1, got sums: min={sums.min():.4f}, max={sums.max():.4f}"
    )

    assert torch.all(posterior >= -1e-6), "UDLM posterior should be non-negative"

    print("✓ UDLM posterior test passed")
|
|
|
|
def test_gumbel_softmax():
    """Check temperature behaviour of ``gumbel_softmax`` with noise disabled.

    The input log-probabilities strongly favour index 3. The test asserts:
    * at temperature 0.01 the argmax is index 3 everywhere;
    * at temperature 100 no probability reaches 0.5;
    * the low-temperature output is non-negative and sums to 1 over the
      last axis (within 1e-5).
    """
    from cdd.samplers.cdd_sampler import gumbel_softmax

    favoured = 3

    # All-zero logits except one clearly dominant entry per position.
    logits = torch.zeros(2, 5, 10)
    logits[..., favoured] = 10.0
    log_probs = F.log_softmax(logits, dim=-1)

    # Low temperature.
    result = gumbel_softmax(log_probs, temperature=0.01, use_noise=False)
    picked = result.argmax(dim=-1)
    assert torch.all(picked == favoured), (
        "Low temperature Gumbel-Softmax should approximate argmax"
    )

    # High temperature.
    result_high = gumbel_softmax(log_probs, temperature=100.0, use_noise=False)
    largest = result_high.max(dim=-1).values
    assert torch.all(largest < 0.5), (
        "High temperature should give more uniform distribution"
    )

    # The low-temperature output must be a valid distribution.
    assert torch.all(result >= 0), "Output should be non-negative"
    sums = result.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5), (
        "Output should sum to 1"
    )

    print("✓ Gumbel-Softmax test passed")
|
|
|
|
def test_alm_projection():
    """Run ``alm_projection`` on a one-entry hinge constraint.

    The constraint is ``relu(0.9 - x[0, 0, 5])``, which is zero once entry
    (0, 0, 5) reaches 0.9. The checks are loose: the output must keep the
    input shape, be non-negative, sum to 1 over the last axis within 0.1,
    and its violation may exceed the initial one by at most 0.1.
    """
    from cdd.samplers.cdd_sampler import alm_projection, ALMConfig

    batch, length, vocab = 1, 4, 10

    # Random starting distribution over the vocabulary.
    x_theta = F.softmax(torch.randn(batch, length, vocab), dim=-1)

    def constraint_fn(x_tilde):
        # Shortfall of entry (0, 0, 5) below 0.9; zero when satisfied.
        return F.relu(0.9 - x_tilde[0, 0, 5])

    settings = dict(
        lambda_init=0.0,
        mu_init=1.0,
        mu_max=100.0,
        outer_iter_max=50,
        inner_iter_max=5,
        eta=0.5,
        gumbel_temperature=0.1,
        use_gumbel_noise=False,
    )
    config = ALMConfig(**settings)

    x_proj = alm_projection(x_theta, constraint_fn, config)

    assert x_proj.shape == x_theta.shape, "Shape should be preserved"

    assert torch.all(x_proj >= 0), "Projected should be non-negative"
    sums = x_proj.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=0.1), (
        f"Projected should approximately sum to 1, got {sums}"
    )

    # Compare the violation before and after projection.
    initial_violation = constraint_fn(x_theta).item()
    final_violation = constraint_fn(x_proj).item()
    allowed = initial_violation + 0.1
    assert final_violation <= allowed, (
        f"Projection should reduce violation: {initial_violation:.4f} → {final_violation:.4f}"
    )

    print(f"✓ ALM projection test passed (violation: {initial_violation:.4f} → {final_violation:.4f})")
|
|
|
|
def test_mdlm_forward_sample():
    """Check that ``mdlm_forward_sample`` masks more tokens at a later time.

    Noises the same clean (4, 16) batch at t=0.01 and at t=0.99 and asserts
    that the second call yields strictly more mask tokens than the first.
    """
    from cdd.utils.noise_schedule import mdlm_forward_sample

    batch, length, vocab = 4, 16, 50
    mask_id = vocab - 1

    # Clean tokens never include the mask id itself.
    x_0 = torch.randint(0, vocab - 1, (batch, length))

    def count_masked(time_value):
        # Noise x_0 at a single shared time and count resulting mask tokens.
        t = torch.full((batch,), time_value)
        z = mdlm_forward_sample(x_0, t, mask_id, vocab)
        return (z == mask_id).sum().item()

    n_masked_clean = count_masked(0.01)
    n_masked_noisy = count_masked(0.99)

    assert n_masked_noisy > n_masked_clean, (
        f"More masking at t=1 ({n_masked_noisy}) vs t=0 ({n_masked_clean})"
    )

    print(f"✓ MDLM forward sample test passed "
          f"(masked: t≈0: {n_masked_clean}, t≈1: {n_masked_noisy})")
|
|
|
|
def test_counting_constraint():
    """Check the ``correct_count`` attribute of ``CountingConstraint``.

    Uses a stand-in tokenizer and asserts the count for three word/letter
    pairs: 's' in 'mississippi' (4), 'z' in 'hello' (0), 'l' in 'hello' (2).
    """
    from cdd.constraints.instruction import CountingConstraint

    class MockTokenizer:
        """Minimal tokenizer exposing only ``encode``."""

        def encode(self, text, add_special_tokens=False):
            # Digit strings map to one id offset by 100; any other text
            # maps to one code point per character.
            if not text.isdigit():
                return list(map(ord, text))
            return [100 + int(text)]

    tokenizer = MockTokenizer()

    def counted(word, letter):
        # Build a constraint and read back the count it computed.
        return CountingConstraint(tokenizer, word, letter).correct_count

    s_count = counted("mississippi", "s")
    assert s_count == 4, f"Expected 4 's' in 'mississippi', got {s_count}"

    z_count = counted("hello", "z")
    assert z_count == 0, f"Expected 0 'z' in 'hello', got {z_count}"

    l_count = counted("hello", "l")
    assert l_count == 2

    print("✓ Counting constraint test passed")
|
|
|
|
def test_full_sampling_mock():
    """Drive ``CDDSampler.sample`` end to end using stand-in objects.

    A mock model returning random logits and a mock tokenizer are passed
    to the sampler. Three runs are checked:
    * unconstrained: the result has "sequences" and "text" entries of the
      expected shape/length;
    * constrained (with an always-zero constraint): sequences keep their
      shape;
    * with a three-token prefix: the prefix appears unchanged at the start
      of the output sequence.
    """
    from cdd.samplers.cdd_sampler import CDDSampler, ALMConfig

    seq_len, vocab, hidden = 8, 20, 64

    class MockModel(torch.nn.Module):
        """Stand-in model exposing ``config``, ``backbone`` and ``forward``."""

        def __init__(self):
            super().__init__()
            config_cls = type('Config', (), {
                'vocab_size': vocab,
                'model_length': seq_len,
                'hidden_dim': hidden,
            })
            self.config = config_cls()

            # Nested attribute path: backbone.vocab_embed.embedding
            embed_cls = type('Embed', (), {
                'embedding': torch.randn(vocab, hidden)
            })
            backbone_cls = type('Backbone', (), {
                'vocab_embed': embed_cls()
            })
            self.backbone = backbone_cls()

        def forward(self, input_ids=None, timesteps=None, **kwargs):
            # Random logits sized from the incoming batch dimension.
            rows = input_ids.shape[0]
            output_cls = type('Output', (), {
                'logits': torch.randn(rows, seq_len, vocab)
            })
            return output_cls()

    class MockTokenizer:
        """Stand-in tokenizer with fixed encode/decode results."""

        mask_token_id = vocab - 1

        def decode(self, ids, skip_special_tokens=True):
            return "mock output text"

        def encode(self, text, add_special_tokens=False, max_length=None, truncation=False):
            return [1, 2, 3]

    model = MockModel()
    tokenizer = MockTokenizer()

    def no_constraint(x):
        # Constraint that always reports zero violation.
        return torch.tensor(0.0)

    config = ALMConfig(outer_iter_max=2, inner_iter_max=2, eta=0.1)

    sampler = CDDSampler(
        model=model,
        tokenizer=tokenizer,
        constraint_fn=no_constraint,
        alm_config=config,
        diffusion_type="mdlm",
        num_timesteps=5,
        seq_length=seq_len,
        device="cpu",
    )

    expected_shape = (1, seq_len)

    # Run 1: constraints switched off.
    result = sampler.sample(batch_size=1, apply_constraints=False)

    assert "sequences" in result
    assert "text" in result
    assert result["sequences"].shape == expected_shape
    assert len(result["text"]) == 1

    # Run 2: constraints switched on.
    result_constrained = sampler.sample(batch_size=1, apply_constraints=True)
    assert result_constrained["sequences"].shape == expected_shape

    # Run 3: fixed prefix must appear at the start of the output.
    prefix = torch.tensor([[1, 2, 3]], dtype=torch.long)
    result_prefix = sampler.sample(
        batch_size=1, prefix_ids=prefix, apply_constraints=False,
    )
    generated_head = result_prefix["sequences"][0, :3]
    assert torch.all(generated_head == prefix[0])

    print("✓ Full sampling loop test passed")
|
|
|
|
def run_all_tests():
    """Run every test in this file and print a pass/fail summary.

    Each test is called in a fixed order. A test that raises is reported
    with its name, the exception message and a traceback, and the run
    continues with the next test.

    Returns:
        True when no test failed, False otherwise.
    """
    import traceback

    rule = "=" * 60
    print(rule)
    print("CDD Implementation Unit Tests")
    print(rule)
    print()

    tests = [
        test_noise_schedule,
        test_simplex_projection,
        test_mdlm_posterior,
        test_udlm_posterior,
        test_gumbel_softmax,
        test_alm_projection,
        test_mdlm_forward_sample,
        test_counting_constraint,
        test_full_sampling_mock,
    ]

    def succeeded(test):
        # Run one test; report and absorb any failure so the run continues.
        try:
            test()
        except Exception as e:
            print(f"✗ {test.__name__} FAILED: {e}")
            traceback.print_exc()
            return False
        return True

    outcomes = [succeeded(test) for test in tests]
    passed = sum(outcomes)
    failed = len(outcomes) - passed

    print()
    print(rule)
    print(f"Results: {passed} passed, {failed} failed out of {len(tests)}")
    print(rule)

    return failed == 0
|
|
|
|
if __name__ == "__main__":
    # Exit status 0 when every test passed, 1 otherwise.
    sys.exit(int(not run_all_tests()))
|
|