File size: 3,676 Bytes
14b57af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import pytest
import torch
from ltx_video.schedulers.rf import RectifiedFlowScheduler


def init_latents_and_scheduler(sampler):
    """Build a RectifiedFlowScheduler and matching random latents.

    Args:
        sampler: Sampler name; "uniform" (case-insensitive) selects the
            "Uniform" sampler, any other value selects "LinearQuadratic".

    Returns:
        A ``(scheduler, latents)`` pair where ``latents`` has shape
        (batch=2, tokens=4096, channels=128) and the scheduler has been
        initialized with 20 inference timesteps for those samples.
    """
    latent_shape = (2, 4096, 128)  # (batch_size, n_tokens, n_channels)
    if sampler.lower() == "uniform":
        sampler_name = "Uniform"
    else:
        sampler_name = "LinearQuadratic"
    scheduler = RectifiedFlowScheduler(sampler=sampler_name)
    latents = torch.randn(size=latent_shape)
    scheduler.set_timesteps(num_inference_steps=20, samples=latents)
    return scheduler, latents


@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_default_behavior(sampler):
    """
    Test the case of a single timestep from the list of timesteps.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    timestep_list = scheduler.timesteps
    n_steps = len(timestep_list)

    for idx, current_t in enumerate(timestep_list):
        velocity = torch.randn_like(latents)
        stepped_latents = scheduler.step(
            velocity,
            current_t,
            latents,
            return_dict=False,
        )[0]

        # The step should be a rectified-flow Euler update toward the next
        # timestep in the schedule (or all the way to t=0 on the final step).
        target_t = timestep_list[idx + 1] if idx + 1 < n_steps else 0.0
        step_size = current_t - target_t
        expected = latents - step_size * velocity
        assert torch.allclose(stepped_latents, expected, atol=1e-06)


@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_per_token(sampler):
    """
    Test the case of a timestep per token (from the list of timesteps).
    Some tokens are set with timestep of 0.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    batch_size, n_tokens = latents.shape[0], latents.shape[1]
    timestep_list = scheduler.timesteps
    n_steps = len(timestep_list)

    for idx, current_t in enumerate(timestep_list):
        # Same timestep for all tokens, except token 0 which is already clean.
        per_token_t = torch.full((batch_size, n_tokens), current_t)
        per_token_t[:, 0] = 0.0
        velocity = torch.randn_like(latents)
        stepped_latents = scheduler.step(
            velocity,
            per_token_t,
            latents,
            return_dict=False,
        )[0]

        # Expected per-token Euler update toward the next scheduled timestep
        # (0.0 after the final step).
        target_t = timestep_list[idx + 1] if idx + 1 < n_steps else 0.0
        per_token_target = torch.full((batch_size, n_tokens), target_t)
        step_size = per_token_t - per_token_target
        expected = latents - step_size.unsqueeze(-1) * velocity
        assert torch.allclose(stepped_latents[:, 1:], expected[:, 1:], atol=1e-06)
        # Tokens at t=0 must pass through unchanged.
        assert torch.allclose(stepped_latents[:, 0], latents[:, 0], atol=1e-06)


@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_t_not_in_list(sampler):
    """
    Test the case of a timestep per token NOT from the list of timesteps.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    batch_size, n_tokens = latents.shape[0], latents.shape[1]
    timestep_list = scheduler.timesteps
    n_steps = len(timestep_list)

    for idx, scheduled_t in enumerate(timestep_list):
        # Pick a timestep halfway between this scheduled step and the next
        # one (or halfway to zero on the last step) so it is off-schedule.
        if idx + 1 < n_steps:
            off_schedule_t = (scheduled_t + timestep_list[idx + 1]) / 2
        else:
            off_schedule_t = scheduled_t / 2
        per_token_t = torch.full((batch_size, n_tokens), off_schedule_t)
        velocity = torch.randn_like(latents)
        stepped_latents = scheduler.step(
            velocity,
            per_token_t,
            latents,
            return_dict=False,
        )[0]

        # The step should still land on the next scheduled timestep
        # (0.0 after the final step).
        target_t = timestep_list[idx + 1] if idx + 1 < n_steps else 0.0
        per_token_target = torch.full((batch_size, n_tokens), target_t)
        step_size = per_token_t - per_token_target
        expected = latents - step_size.unsqueeze(-1) * velocity
        assert torch.allclose(stepped_latents, expected, atol=1e-06)