| """Tests for Noise Scheduler.""" |
|
|
| import torch |
| import pytest |
| from diffusion_llm.model.noise_scheduler import NoiseScheduler |
|
|
|
|
class TestNoiseScheduler:
    """Test suite for the NoiseScheduler."""

    @staticmethod
    def _assert_valid_betas(scheduler, n_timesteps):
        """Assert ``scheduler.betas`` has one entry per timestep, each strictly in (0, 1)."""
        betas = scheduler.betas
        assert betas.shape == (n_timesteps,)
        # Every schedule type is held to the same open (0, 1) range so the
        # per-schedule tests stay consistent with one another.
        assert (betas > 0).all()
        assert (betas < 1).all()

    def test_cosine_schedule(self):
        """Test cosine noise schedule creation."""
        scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="cosine")
        self._assert_valid_betas(scheduler, 1000)

    def test_linear_schedule(self):
        """Test linear noise schedule creation (betas increase over time)."""
        scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="linear")
        self._assert_valid_betas(scheduler, 1000)
        assert scheduler.betas[0] < scheduler.betas[-1]

    def test_sigmoid_schedule(self):
        """Test sigmoid noise schedule creation."""
        scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="sigmoid")
        self._assert_valid_betas(scheduler, 1000)

    def test_add_noise(self):
        """Test forward diffusion (adding noise) with per-sample timesteps."""
        scheduler = NoiseScheduler(n_timesteps=1000)
        x_0 = torch.randn(2, 10, 64)
        noise = torch.randn_like(x_0)
        t = torch.tensor([0, 500])

        x_t = scheduler.add_noise(x_0, noise, t)
        assert x_t.shape == x_0.shape
        assert torch.isfinite(x_t).all()

        # Each batch element must be noised at its own timestep: the sample at
        # t=0 should stay much closer to its clean input than the one at t=500.
        drift = torch.linalg.vector_norm(x_t - x_0, dim=(1, 2))
        assert drift[0] < drift[1]

    def test_loss_target_epsilon(self):
        """Test epsilon prediction target."""
        scheduler = NoiseScheduler(prediction_type="epsilon")
        x_0 = torch.randn(2, 10, 64)
        noise = torch.randn_like(x_0)
        t = torch.tensor([100, 500])

        target = scheduler.compute_loss_target(x_0, noise, t)
        assert torch.allclose(target, noise)

    def test_loss_target_x0(self):
        """Test x0 prediction target."""
        scheduler = NoiseScheduler(prediction_type="x0")
        x_0 = torch.randn(2, 10, 64)
        noise = torch.randn_like(x_0)
        t = torch.tensor([100, 500])

        target = scheduler.compute_loss_target(x_0, noise, t)
        assert torch.allclose(target, x_0)

    def test_predict_x0_from_epsilon(self):
        """Test that x0 is recovered from x_t given the exact noise."""
        scheduler = NoiseScheduler(prediction_type="epsilon")
        x_0 = torch.randn(2, 10, 64)
        noise = torch.randn_like(x_0)
        t = torch.tensor([100])

        x_t = scheduler.add_noise(x_0, noise, t)
        x_0_pred = scheduler.predict_x0_from_epsilon(x_t, noise, t)

        assert x_0_pred.shape == x_0.shape
        # With the true epsilon, inverting the forward process should give back
        # the clean input up to floating-point rounding. A shape-only check
        # would also pass for a function that returned garbage.
        # NOTE(review): this assumes predict_x0_from_epsilon does not clip its
        # output; relax if the scheduler intentionally clamps x0 predictions.
        assert torch.allclose(x_0_pred, x_0, atol=1e-4)

    def test_ddpm_step(self):
        """Test single DDPM reverse step."""
        scheduler = NoiseScheduler(n_timesteps=1000)
        x_t = torch.randn(2, 10, 64)
        model_output = torch.randn_like(x_t)
        t = torch.tensor([500, 500])

        x_prev = scheduler.step_ddpm(model_output, x_t, t)
        assert x_prev.shape == x_t.shape
        assert torch.isfinite(x_prev).all()

    def test_ddim_step(self):
        """Test single DDIM reverse step."""
        scheduler = NoiseScheduler(n_timesteps=1000)
        x_t = torch.randn(2, 10, 64)
        model_output = torch.randn_like(x_t)

        x_prev = scheduler.step_ddim(model_output, x_t, t=500, t_prev=400)
        assert x_prev.shape == x_t.shape
        assert torch.isfinite(x_prev).all()

    def test_timestep_schedule(self):
        """Test inference timestep schedule (non-empty, runs from high t to low t)."""
        scheduler = NoiseScheduler(n_timesteps=1000)
        schedule = scheduler.get_timestep_schedule(n_inference_steps=50)
        assert len(schedule) > 0
        assert schedule[0] > schedule[-1]
|
|