"""
Unit tests for the CDD implementation.

Tests core components without requiring a GPU or large model downloads:
1. Noise schedule correctness
2. Posterior computation (MDLM and UDLM)
3. Gumbel-Softmax relaxation
4. ALM projection convergence
5. Simplex projection
6. Constraint function interfaces
7. MDLM forward noising
8. The full sampling loop, driven by a mock model and tokenizer
"""
import sys
import torch
import torch.nn.functional as F
import numpy as np
# Make the project's `cdd` package importable when this file is run directly.
# NOTE(review): this inserts the hard-coded absolute path '/app' (presumably the
# repository root inside a container), not this file's parent directory —
# confirm the path before running the tests from another location.
sys.path.insert(0, '/app')
def test_noise_schedule():
    """Test log-linear noise schedule properties.

    Samples the schedule on a uniform grid over t in [0, 1] and checks that
    alpha(t) is monotonically non-increasing, is close to 1 at t=0 and close
    to 0 at t=1.
    """
    from cdd.utils.noise_schedule import log_linear_schedule

    t = torch.linspace(0, 1, 100)
    alpha = log_linear_schedule(t)

    # α should be (non-strictly) decreasing in t.
    assert torch.all(alpha[:-1] >= alpha[1:]), "Alpha should be monotonically decreasing"

    # α(0) ≈ 1.0: essentially no noise at the start of the schedule.
    alpha_start = alpha[0].item()
    assert alpha_start > 0.99, f"Alpha at t=0 should be ≈1.0, got {alpha_start:.4f}"

    # α(1) ≈ eps: essentially fully noised at the end of the schedule.
    alpha_end = alpha[-1].item()
    assert alpha_end < 0.001, f"Alpha at t=1 should be ≈0, got {alpha_end:.6f}"

    print("✓ Noise schedule test passed")
def test_simplex_projection():
    """Test simplex projection correctness.

    Projects random, unconstrained vectors onto the probability simplex and
    checks that the result keeps its shape, is non-negative, and sums to 1
    along the last dim. Then checks that a vector already on the simplex is
    left (nearly) unchanged.
    """
    from cdd.utils.noise_schedule import project_to_simplex

    # A fixed-seed local generator makes failures reproducible without
    # disturbing the global RNG state that other tests rely on.
    gen = torch.Generator().manual_seed(0)

    # Random input of arbitrary sign and scale.
    x = torch.randn(2, 5, 10, generator=gen)
    projected = project_to_simplex(x, dim=-1)

    assert projected.shape == x.shape, \
        f"Projection should preserve shape {tuple(x.shape)}, got {tuple(projected.shape)}"
    # Small negative tolerance absorbs floating-point round-off.
    assert torch.all(projected >= -1e-6), "Projected values should be non-negative"
    sums = projected.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5), \
        f"Projected should sum to 1, got sums in [{sums.min().item():.4f}, {sums.max().item():.4f}]"

    # A point already on the simplex is its own projection.
    y = F.softmax(torch.randn(3, 10, generator=gen), dim=-1)
    y_proj = project_to_simplex(y, dim=-1)
    assert torch.allclose(y, y_proj, atol=1e-4), "Simplex input should be unchanged"

    print("✓ Simplex projection test passed")
def test_mdlm_posterior():
    """Test MDLM posterior computation.

    Builds a batch of sequences with a fixed block of masked positions, computes
    the posterior from t=0.5 to s=0.4, and checks shape, normalization and
    non-negativity. It also checks that every unmasked position stays one-hot
    on its current token.
    """
    from cdd.utils.noise_schedule import mdlm_posterior, log_linear_schedule

    B, L, V = 2, 8, 20
    mask_id = V - 1
    gen = torch.Generator().manual_seed(0)

    # Single source of truth for which positions are masked (positions 3-5).
    masked = torch.zeros(L, dtype=torch.bool)
    masked[3:6] = True

    # Draw tokens from [0, V-1) so no position is masked by accident, then
    # mask the chosen block explicitly.
    z_t = torch.randint(0, V - 1, (B, L), generator=gen)
    z_t[:, masked] = mask_id
    x_theta = F.softmax(torch.randn(B, L, V, generator=gen), dim=-1)

    t = torch.full((B,), 0.5)
    s = torch.full((B,), 0.4)
    alpha_t = log_linear_schedule(t)
    alpha_s = log_linear_schedule(s)

    posterior = mdlm_posterior(z_t, x_theta, alpha_t, alpha_s, mask_id)

    assert posterior.shape == (B, L, V), f"Expected {(B, L, V)}, got {posterior.shape}"

    # Every position must be a valid probability distribution.
    sums = posterior.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-4), \
        f"Posterior should sum to 1, got {sums}"
    assert torch.all(posterior >= -1e-6), "Posterior should be non-negative"

    # Unmasked positions (derived from the mask, not hard-coded) should be
    # one-hot at the token they already hold.
    unmasked_positions = (~masked).nonzero().flatten().tolist()
    for b in range(B):
        for pos in unmasked_positions:
            token = z_t[b, pos].item()
            prob = posterior[b, pos, token].item()
            assert prob > 0.99, \
                f"Unmasked position (b={b}, pos={pos}) should be one-hot, got {prob:.4f}"

    print("✓ MDLM posterior test passed")
def test_udlm_posterior():
    """Test UDLM posterior computation.

    Computes the posterior from t=0.5 to s=0.4 for random tokens and random
    predicted distributions. Checks shape, normalization and non-negativity.
    """
    from cdd.utils.noise_schedule import udlm_posterior, log_linear_schedule

    B, L, V = 2, 8, 20
    gen = torch.Generator().manual_seed(0)

    # Tokens come from the full vocabulary; udlm_posterior takes no mask id.
    z_t = torch.randint(0, V, (B, L), generator=gen)
    x_theta = F.softmax(torch.randn(B, L, V, generator=gen), dim=-1)

    t = torch.full((B,), 0.5)
    s = torch.full((B,), 0.4)
    alpha_t = log_linear_schedule(t)
    alpha_s = log_linear_schedule(s)

    posterior = udlm_posterior(z_t, x_theta, alpha_t, alpha_s)

    assert posterior.shape == (B, L, V), f"Expected {(B, L, V)}, got {posterior.shape}"

    # Every position must be a valid probability distribution.
    sums = posterior.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=1e-4), \
        f"UDLM posterior should sum to 1, got sums: min={sums.min().item():.4f}, max={sums.max().item():.4f}"
    assert torch.all(posterior >= -1e-6), "UDLM posterior should be non-negative"

    print("✓ UDLM posterior test passed")
def test_gumbel_softmax():
    """Test Gumbel-Softmax relaxation properties.

    With noise disabled, a very low temperature should reproduce the argmax of
    a peaked distribution, and a very high temperature should flatten it. Both
    outputs must be valid probability distributions of the input's shape.
    """
    from cdd.samplers.cdd_sampler import gumbel_softmax

    # Peaked distribution: index 3 dominates at every position.
    logits = torch.zeros(2, 5, 10)
    logits[:, :, 3] = 10.0
    log_probs = F.log_softmax(logits, dim=-1)

    # Low temperature should approximate argmax.
    result = gumbel_softmax(log_probs, temperature=0.01, use_noise=False)
    assert torch.all(result.argmax(dim=-1) == 3), \
        "Low temperature Gumbel-Softmax should approximate argmax"

    # High temperature should be close to uniform.
    result_high = gumbel_softmax(log_probs, temperature=100.0, use_noise=False)
    max_probs = result_high.max(dim=-1).values
    assert torch.all(max_probs < 0.5), \
        "High temperature should give more uniform distribution"

    # Validate BOTH outputs, not only the low-temperature one.
    for label, probs in (("low temperature", result), ("high temperature", result_high)):
        assert probs.shape == log_probs.shape, \
            f"Output ({label}) should have shape {tuple(log_probs.shape)}, got {tuple(probs.shape)}"
        assert torch.all(probs >= 0), f"Output ({label}) should be non-negative"
        sums = probs.sum(dim=-1)
        assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5), \
            f"Output ({label}) should sum to 1"

    print("✓ Gumbel-Softmax test passed")
def test_alm_projection():
    """Test ALM projection converges to satisfy a simple constraint.

    The constraint asks for probability >= 0.9 on token 5 at position 0 of the
    first batch element. The test checks that the projection preserves shape,
    returns non-negative values that approximately sum to 1, and does not
    increase the violation by more than a small tolerance.
    """
    from cdd.samplers.cdd_sampler import alm_projection, ALMConfig

    B, L, V = 1, 4, 10
    gen = torch.Generator().manual_seed(0)
    x_theta = F.softmax(torch.randn(B, L, V, generator=gen), dim=-1)

    # Constraint g(x̃) = max(0, 0.9 - x̃[0, 0, 5]): zero once token 5 at
    # position 0 has probability >= 0.9, positive (violated) otherwise.
    def constraint_fn(x_tilde):
        prob_target = x_tilde[0, 0, 5]
        return F.relu(0.9 - prob_target)

    config = ALMConfig(
        lambda_init=0.0,
        mu_init=1.0,
        mu_max=100.0,
        outer_iter_max=50,
        inner_iter_max=5,
        eta=0.5,
        gumbel_temperature=0.1,
        use_gumbel_noise=False,  # Deterministic for testing
    )

    # Measure the starting violation BEFORE projecting, so the baseline stays
    # correct even if the projection modifies its input in place.
    initial_violation = constraint_fn(x_theta).item()
    x_proj = alm_projection(x_theta, constraint_fn, config)
    final_violation = constraint_fn(x_proj).item()

    assert x_proj.shape == x_theta.shape, "Shape should be preserved"
    assert torch.all(x_proj >= 0), "Projected should be non-negative"
    sums = x_proj.sum(dim=-1)
    assert torch.allclose(sums, torch.ones_like(sums), atol=0.1), \
        f"Projected should approximately sum to 1, got {sums}"

    # Allow a small slack: the projection trades constraint satisfaction
    # against staying close to x_theta within a bounded number of iterations.
    tolerance = 0.1
    assert final_violation <= initial_violation + tolerance, \
        f"Projection should reduce violation: {initial_violation:.4f} → {final_violation:.4f}"

    print(f"✓ ALM projection test passed (violation: {initial_violation:.4f} → {final_violation:.4f})")
def test_mdlm_forward_sample():
    """Test MDLM forward noising process.

    Noises the same clean batch at a time near 0 and at a time near 1, and
    checks that more positions are replaced by the mask token at the later
    time.
    """
    from cdd.utils.noise_schedule import mdlm_forward_sample

    B, L = 4, 16
    V = 50
    mask_id = V - 1

    # mdlm_forward_sample takes no generator argument, so it presumably draws
    # from the global RNG. Seed it inside fork_rng: runs are reproducible, and
    # the CPU RNG state is restored afterwards for other tests.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(0)
        x_0 = torch.randint(0, V - 1, (B, L))  # Clean tokens (no mask tokens)

        # t ≈ 0 (clean): almost no masking expected.
        t_clean = torch.full((B,), 0.01)
        z_clean = mdlm_forward_sample(x_0, t_clean, mask_id, V)

        # t ≈ 1 (noisy): almost everything masked expected.
        t_noisy = torch.full((B,), 0.99)
        z_noisy = mdlm_forward_sample(x_0, t_noisy, mask_id, V)

    n_masked_clean = (z_clean == mask_id).sum().item()
    n_masked_noisy = (z_noisy == mask_id).sum().item()

    assert n_masked_noisy > n_masked_clean, \
        f"Expected more masking at t→1 ({n_masked_noisy}) than at t→0 ({n_masked_clean})"

    print(f"✓ MDLM forward sample test passed "
          f"(masked: t→0: {n_masked_clean}, t→1: {n_masked_noisy})")
def test_counting_constraint():
    """Test counting constraint logic.

    Checks that CountingConstraint computes the expected number of
    occurrences of a letter in a word, including the zero-occurrence case.
    """
    from cdd.constraints.instruction import CountingConstraint

    class MockTokenizer:
        """Minimal tokenizer stand-in exposing only ``encode``."""

        def encode(self, text, add_special_tokens=False):
            # An all-digit string maps to the single id 100 + int(text)
            # (e.g. "4" -> [104]). Anything else maps to one id per
            # character (its code point).
            if text.isdigit():
                return [100 + int(text)]
            return [ord(c) for c in text]

    tokenizer = MockTokenizer()

    # (word, letter, expected count), e.g. "How many 's' in 'mississippi'?" → 4.
    cases = [
        ("mississippi", "s", 4),
        ("hello", "z", 0),
        ("hello", "l", 2),
    ]
    for word, letter, expected in cases:
        constraint = CountingConstraint(tokenizer, word, letter)
        assert constraint.correct_count == expected, \
            f"Expected {expected} {letter!r} in {word!r}, got {constraint.correct_count}"

    print("✓ Counting constraint test passed")
def test_full_sampling_mock():
    """Test the full sampling loop with a mock model.

    Runs CDDSampler end to end on CPU with a tiny mock model (random logits)
    and a mock tokenizer. It covers three modes: unconstrained sampling,
    sampling with constraints applied, and prefix-conditioned sampling.
    """
    from cdd.samplers.cdd_sampler import CDDSampler, ALMConfig

    B, L, V = 1, 8, 20
    hidden_dim = 64

    class MockModel(torch.nn.Module):
        """Model stand-in whose forward returns random logits of shape (batch, L, V)."""

        def __init__(self):
            super().__init__()
            self.config = type('Config', (), {
                'vocab_size': V,
                'model_length': L,
                'hidden_dim': hidden_dim,
            })()
            self.backbone = type('Backbone', (), {
                'vocab_embed': type('Embed', (), {
                    'embedding': torch.randn(V, hidden_dim),
                })(),
            })()

        def forward(self, input_ids=None, timesteps=None, **kwargs):
            batch = input_ids.shape[0]
            logits = torch.randn(batch, L, V)
            return type('Output', (), {'logits': logits})()

    class MockTokenizer:
        """Tokenizer stand-in with a fixed mask id and canned encode/decode."""

        mask_token_id = V - 1

        def decode(self, ids, skip_special_tokens=True):
            return "mock output text"

        def encode(self, text, add_special_tokens=False, max_length=None, truncation=False):
            return [1, 2, 3]

    def no_constraint(x):
        """Always-satisfied constraint (zero violation)."""
        return torch.tensor(0.0)

    config = ALMConfig(
        outer_iter_max=2,
        inner_iter_max=2,
        eta=0.1,
    )

    # The mock model draws its embedding and logits from the global RNG. Seed it
    # inside fork_rng so runs are reproducible and the CPU RNG state is
    # restored afterwards.
    with torch.random.fork_rng(devices=[]):
        torch.manual_seed(0)

        sampler = CDDSampler(
            model=MockModel(),
            tokenizer=MockTokenizer(),
            constraint_fn=no_constraint,
            alm_config=config,
            diffusion_type="mdlm",
            num_timesteps=5,  # Very few steps for testing
            seq_length=L,
            device="cpu",
        )

        # Unconstrained sampling.
        result = sampler.sample(batch_size=B, apply_constraints=False)
        assert "sequences" in result
        assert "text" in result
        assert result["sequences"].shape == (B, L)
        assert len(result["text"]) == B

        # Sampling with the ALM projection enabled.
        result_constrained = sampler.sample(batch_size=B, apply_constraints=True)
        assert result_constrained["sequences"].shape == (B, L)

        # Prefix-conditioned sampling: the prefix must be kept verbatim.
        prefix = torch.tensor([[1, 2, 3]], dtype=torch.long)
        result_prefix = sampler.sample(
            batch_size=B, prefix_ids=prefix, apply_constraints=False,
        )
        prefix_len = prefix.shape[1]
        assert torch.all(result_prefix["sequences"][0, :prefix_len] == prefix[0])

    print("✓ Full sampling loop test passed")
def run_all_tests(tests=None):
    """Run unit tests, print a per-test report and a summary, and return success.

    Each test is a zero-argument callable that signals failure by raising.
    A failing test is reported with its traceback, and the remaining tests
    still run.

    Args:
        tests: Optional iterable of zero-argument test callables. Defaults to
            every test in this module, in declaration order.

    Returns:
        True if every test ran without raising, False otherwise.
    """
    import traceback

    print("=" * 60)
    print("CDD Implementation Unit Tests")
    print("=" * 60)
    print()

    if tests is None:
        tests = [
            test_noise_schedule,
            test_simplex_projection,
            test_mdlm_posterior,
            test_udlm_posterior,
            test_gumbel_softmax,
            test_alm_projection,
            test_mdlm_forward_sample,
            test_counting_constraint,
            test_full_sampling_mock,
        ]
    else:
        # Materialize so len() works for arbitrary iterables.
        tests = list(tests)

    passed = 0
    failed = 0
    for test in tests:
        try:
            test()
        except Exception as e:
            # Top-level test-runner boundary: report and keep going so one
            # failure does not hide the results of the remaining tests.
            name = getattr(test, "__name__", repr(test))
            print(f"✗ {name} FAILED: {e}")
            traceback.print_exc()
            failed += 1
        else:
            passed += 1
        print()

    print("=" * 60)
    print(f"Results: {passed} passed, {failed} failed out of {len(tests)}")
    print("=" * 60)
    return failed == 0
if __name__ == "__main__":
    # Report the overall outcome through the process exit status
    # (0 when every test passed, 1 otherwise) so CI can act on it.
    sys.exit(0 if run_all_tests() else 1)