import pytest
import torch

from ltx_video.schedulers.rf import RectifiedFlowScheduler


def init_latents_and_scheduler(sampler):
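    # Build a scheduler with the requested sampler and random latents, then
    # initialize the timestep schedule from the latents' shape.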
    batch_size, n_tokens, n_channels = 2, 4096, 128
    num_steps = 20
    scheduler = RectifiedFlowScheduler(
        sampler=("Uniform" if sampler.lower() == "uniform" else "LinearQuadratic")
    )
    latents = torch.randn(size=(batch_size, n_tokens, n_channels))
    scheduler.set_timesteps(num_inference_steps=num_steps, samples=latents)
    return scheduler, latents


@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_default_behavior(sampler):
    """
    Test the case of a single timestep from the list of timesteps.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    for i, t in enumerate(scheduler.timesteps):
        noise_pred = torch.randn_like(latents)
        denoised_latents = scheduler.step(
            noise_pred,
            t,
            latents,
            return_dict=False,
        )[0]
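
        # Expected update is a plain Euler step: x <- x - dt * noise_pred,
        # with dt the gap to the next scheduled timestep (0.0 after the last).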
        next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0
        dt = t - next_t
        expected_denoised_latents = latents - dt * noise_pred
        assert torch.allclose(denoised_latents, expected_denoised_latents, atol=1e-06)


@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_per_token(sampler):
    """
    Test the case of a timestep per token (from the list of timesteps).
    Some tokens are set with a timestep of 0.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    batch_size, n_tokens = latents.shape[:2]
    for i, t in enumerate(scheduler.timesteps):
        timesteps = torch.full((batch_size, n_tokens), t)
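        # Mark the first token as fully denoised (t == 0); the step must
        # leave it unchanged.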
        timesteps[:, 0] = 0.0
        noise_pred = torch.randn_like(latents)
        denoised_latents = scheduler.step(
            noise_pred,
            timesteps,
            latents,
            return_dict=False,
        )[0]

        next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0
        next_timesteps = torch.full((batch_size, n_tokens), next_t)
        dt = timesteps - next_timesteps
        expected_denoised_latents = latents - dt.unsqueeze(-1) * noise_pred
        assert torch.allclose(
            denoised_latents[:, 1:], expected_denoised_latents[:, 1:], atol=1e-06
        )
        assert torch.allclose(denoised_latents[:, 0], latents[:, 0], atol=1e-06)


@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_t_not_in_list(sampler):
    """
    Test the case of a timestep per token NOT from the list of timesteps.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    batch_size, n_tokens = latents.shape[:2]
    for i in range(len(scheduler.timesteps)):
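        # Pick a t strictly between two scheduled timesteps (or below the
        # last one) to exercise off-grid timestep handling.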
        if i < len(scheduler.timesteps) - 1:
            t = (scheduler.timesteps[i] + scheduler.timesteps[i + 1]) / 2
        else:
            t = scheduler.timesteps[i] / 2
        timesteps = torch.full((batch_size, n_tokens), t)
        noise_pred = torch.randn_like(latents)
        denoised_latents = scheduler.step(
            noise_pred,
            timesteps,
            latents,
            return_dict=False,
        )[0]
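
        # An off-grid t should still advance to the next lower scheduled
        # timestep (0.0 after the last), so dt = t - next_t.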
        next_t = scheduler.timesteps[i + 1] if i < len(scheduler.timesteps) - 1 else 0.0
        next_timesteps = torch.full((batch_size, n_tokens), next_t)
        dt = timesteps - next_timesteps
        expected_denoised_latents = latents - dt.unsqueeze(-1) * noise_pred
        assert torch.allclose(denoised_latents, expected_denoised_latents, atol=1e-06)