File size: 3,671 Bytes
2d7e335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Tests for Noise Scheduler."""

import torch
import pytest
from diffusion_llm.model.noise_scheduler import NoiseScheduler


class TestNoiseScheduler:
    """Test suite for the NoiseScheduler.

    Exercises beta-schedule construction (cosine, linear, sigmoid), the
    forward noising call, loss-target selection for each prediction type,
    x0 recovery from a known epsilon, single DDPM/DDIM reverse steps, and
    the inference timestep schedule.
    """

    def setup_method(self):
        """Seed torch's global RNG before every test.

        The tests draw inputs with ``torch.randn``; a fixed seed makes any
        failure reproducible from run to run.
        """
        torch.manual_seed(0)

    def test_cosine_schedule(self):
        """Test cosine noise schedule creation."""
        scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="cosine")
        betas = scheduler.betas
        assert betas.shape == (1000,)
        assert (betas > 0).all()
        assert (betas < 1).all()

    def test_linear_schedule(self):
        """Test linear noise schedule creation."""
        scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="linear")
        betas = scheduler.betas
        assert betas.shape == (1000,)
        assert betas[0] < betas[-1]  # Increasing overall
        # Check every adjacent pair as well, not only the two endpoints.
        assert (betas[1:] >= betas[:-1]).all()

    def test_sigmoid_schedule(self):
        """Test sigmoid noise schedule creation."""
        scheduler = NoiseScheduler(n_timesteps=1000, schedule_type="sigmoid")
        betas = scheduler.betas
        assert betas.shape == (1000,)
        assert (betas > 0).all()
        # Same upper bound the cosine test enforces.
        assert (betas < 1).all()

    def test_add_noise(self):
        """Test forward diffusion (adding noise)."""
        scheduler = NoiseScheduler(n_timesteps=1000)
        x_0 = torch.randn(2, 10, 64)  # batch=2, seq=10, d=64
        noise = torch.randn_like(x_0)
        t = torch.tensor([0, 500])

        x_t = scheduler.add_noise(x_0, noise, t)
        assert x_t.shape == x_0.shape
        # At t=0, x_t should be close to x_0; at t=500 it should be
        # significantly different. Compare the two batch rows relatively so
        # the check does not depend on the exact schedule constants.
        err_at_t0 = (x_t[0] - x_0[0]).abs().mean()
        err_at_t500 = (x_t[1] - x_0[1]).abs().mean()
        assert err_at_t0 < err_at_t500

    def test_loss_target_epsilon(self):
        """Test epsilon prediction target."""
        scheduler = NoiseScheduler(prediction_type="epsilon")
        x_0 = torch.randn(2, 10, 64)
        noise = torch.randn_like(x_0)
        t = torch.tensor([100, 500])

        target = scheduler.compute_loss_target(x_0, noise, t)
        assert torch.allclose(target, noise)

    def test_loss_target_x0(self):
        """Test x0 prediction target."""
        scheduler = NoiseScheduler(prediction_type="x0")
        x_0 = torch.randn(2, 10, 64)
        noise = torch.randn_like(x_0)
        t = torch.tensor([100, 500])

        target = scheduler.compute_loss_target(x_0, noise, t)
        assert torch.allclose(target, x_0)

    def test_predict_x0_from_epsilon(self):
        """Test x0 prediction from epsilon."""
        scheduler = NoiseScheduler(prediction_type="epsilon")
        x_0 = torch.randn(2, 10, 64)
        noise = torch.randn_like(x_0)
        t = torch.tensor([100])

        x_t = scheduler.add_noise(x_0, noise, t)
        x_0_pred = scheduler.predict_x0_from_epsilon(x_t, noise, t)
        assert x_0_pred.shape == x_0.shape
        # Feeding back the exact noise that was added should recover the
        # original x_0 up to floating-point error.
        # NOTE(review): assumes predict_x0_from_epsilon does not clamp its
        # output -- confirm against the scheduler implementation.
        assert torch.allclose(x_0_pred, x_0, atol=1e-4)

    def test_ddpm_step(self):
        """Test single DDPM reverse step."""
        scheduler = NoiseScheduler(n_timesteps=1000)
        x_t = torch.randn(2, 10, 64)
        model_output = torch.randn_like(x_t)
        t = torch.tensor([500, 500])

        x_prev = scheduler.step_ddpm(model_output, x_t, t)
        assert x_prev.shape == x_t.shape
        # A shape check alone would pass on NaN/inf output.
        assert torch.isfinite(x_prev).all()

    def test_ddim_step(self):
        """Test single DDIM reverse step."""
        scheduler = NoiseScheduler(n_timesteps=1000)
        x_t = torch.randn(2, 10, 64)
        model_output = torch.randn_like(x_t)

        x_prev = scheduler.step_ddim(model_output, x_t, t=500, t_prev=400)
        assert x_prev.shape == x_t.shape
        # A shape check alone would pass on NaN/inf output.
        assert torch.isfinite(x_prev).all()

    def test_timestep_schedule(self):
        """Test inference timestep schedule."""
        scheduler = NoiseScheduler(n_timesteps=1000)
        schedule = scheduler.get_timestep_schedule(n_inference_steps=50)
        assert len(schedule) > 0
        assert schedule[0] > schedule[-1]  # Descending overall
        # Every consecutive pair must descend too, not only the endpoints.
        steps = list(schedule)
        assert all(
            later < earlier for earlier, later in zip(steps, steps[1:])
        )