import pytest
import torch
from ltx_video.schedulers.rf import RectifiedFlowScheduler
def init_latents_and_scheduler(sampler):
    """Create a RectifiedFlowScheduler and matching random latents.

    The scheduler's timesteps are initialized for 20 inference steps.
    Returns a (scheduler, latents) tuple, where latents has shape
    (batch, tokens, channels) = (2, 4096, 128).
    """
    latent_shape = (2, 4096, 128)  # (batch_size, n_tokens, n_channels)
    sampler_name = "Uniform" if sampler.lower() == "uniform" else "LinearQuadratic"
    scheduler = RectifiedFlowScheduler(sampler=sampler_name)
    latents = torch.randn(size=latent_shape)
    scheduler.set_timesteps(num_inference_steps=20, samples=latents)
    return scheduler, latents
@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_default_behavior(sampler):
    """
    Test the case of a single timestep from the list of timesteps.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    all_timesteps = scheduler.timesteps
    last_idx = len(all_timesteps) - 1
    for idx, current_t in enumerate(all_timesteps):
        velocity_pred = torch.randn_like(latents)
        stepped = scheduler.step(
            velocity_pred,
            current_t,
            latents,
            return_dict=False,
        )[0]
        # A rectified-flow step is a plain Euler update: x_next = x - dt * v
        following_t = all_timesteps[idx + 1] if idx < last_idx else 0.0
        step_size = current_t - following_t
        reference = latents - step_size * velocity_pred
        assert torch.allclose(stepped, reference, atol=1e-06)
@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_per_token(sampler):
    """
    Test the case of a timestep per token (from the list of timesteps).
    Some tokens are set with timestep of 0.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    batch, tokens = latents.shape[0], latents.shape[1]
    all_timesteps = scheduler.timesteps
    last_idx = len(all_timesteps) - 1
    for idx, current_t in enumerate(all_timesteps):
        per_token_t = torch.full((batch, tokens), current_t)
        per_token_t[:, 0] = 0.0  # first token: already fully denoised
        velocity_pred = torch.randn_like(latents)
        stepped = scheduler.step(
            velocity_pred,
            per_token_t,
            latents,
            return_dict=False,
        )[0]
        # Verify the Euler update for the tokens with a nonzero timestep.
        following_t = all_timesteps[idx + 1] if idx < last_idx else 0.0
        per_token_following = torch.full((batch, tokens), following_t)
        step_size = per_token_t - per_token_following
        reference = latents - step_size.unsqueeze(-1) * velocity_pred
        assert torch.allclose(stepped[:, 1:], reference[:, 1:], atol=1e-06)
        # Tokens at timestep 0 must pass through unchanged.
        assert torch.allclose(stepped[:, 0], latents[:, 0], atol=1e-06)
@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_t_not_in_list(sampler):
    """
    Test the case of a timestep per token NOT from the list of timesteps.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    batch, tokens = latents.shape[:2]
    all_timesteps = scheduler.timesteps
    last_idx = len(all_timesteps) - 1
    for idx, listed_t in enumerate(all_timesteps):
        # Probe the midpoint between consecutive listed timesteps
        # (for the final step, the midpoint between it and 0).
        following_t = all_timesteps[idx + 1] if idx < last_idx else 0.0
        midpoint_t = (listed_t + following_t) / 2
        per_token_t = torch.full((batch, tokens), midpoint_t)
        velocity_pred = torch.randn_like(latents)
        stepped = scheduler.step(
            velocity_pred,
            per_token_t,
            latents,
            return_dict=False,
        )[0]
        # The step should land on the next *listed* timestep, so dt is
        # measured from the midpoint down to that listed value.
        per_token_following = torch.full((batch, tokens), following_t)
        step_size = per_token_t - per_token_following
        reference = latents - step_size.unsqueeze(-1) * velocity_pred
        assert torch.allclose(stepped, reference, atol=1e-06)