# Source: LTX-itv-Repository / LTX-Video / tests/test_scheduler.py
# Uploaded by vidfom ("Upload folder using huggingface_hub"), commit 14b57af (verified).
import pytest
import torch
from ltx_video.schedulers.rf import RectifiedFlowScheduler
def init_latents_and_scheduler(
    sampler,
    *,
    batch_size=2,
    n_tokens=4096,
    n_channels=128,
    num_steps=20,
):
    """
    Build a RectifiedFlowScheduler and a matching batch of random latents.

    Args:
        sampler: Sampler name. ``"uniform"`` (case-insensitive) selects the
            ``"Uniform"`` sampler; every other value falls back to
            ``"LinearQuadratic"``.
        batch_size: Number of samples in the latent batch.
        n_tokens: Number of tokens per sample.
        n_channels: Number of channels per token.
        num_steps: ``num_inference_steps`` passed to ``set_timesteps``.

    Returns:
        A ``(scheduler, latents)`` tuple. ``latents`` is a standard-normal
        tensor of shape ``(batch_size, n_tokens, n_channels)``, and
        ``scheduler.set_timesteps`` has already been called with it.
    """
    # Any name other than "uniform" silently maps to LinearQuadratic.
    sampler_name = "Uniform" if sampler.lower() == "uniform" else "LinearQuadratic"
    scheduler = RectifiedFlowScheduler(sampler=sampler_name)
    latents = torch.randn(size=(batch_size, n_tokens, n_channels))
    # The latents are handed over as `samples`; presumably the scheduler uses
    # them (e.g. their shape) when building the schedule -- TODO confirm in rf.py.
    scheduler.set_timesteps(num_inference_steps=num_steps, samples=latents)
    return scheduler, latents
@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_default_behavior(sampler):
    """
    Step the scheduler once per scheduled timestep, passing that timestep as a
    single scalar value shared by the whole batch.

    Each step must equal ``latents - (t - t_next) * noise_pred``, where
    ``t_next`` is the following scheduled timestep, or 0.0 after the last one.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    schedule = list(scheduler.timesteps)
    # Pair each timestep with its successor; the final one steps down to 0.0.
    successors = schedule[1:] + [0.0]
    for current_t, following_t in zip(schedule, successors):
        model_out = torch.randn_like(latents)
        stepped = scheduler.step(model_out, current_t, latents, return_dict=False)[0]
        step_size = current_t - following_t
        assert torch.allclose(stepped, latents - step_size * model_out, atol=1e-06)
@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_per_token(sampler):
    """
    Step the scheduler with a per-token timestep tensor of shape
    ``(batch_size, n_tokens)``.

    Every token carries the current scheduled timestep except token 0, which
    is pinned to 0.0. Tokens 1.. must equal
    ``latents - (t - t_next) * noise_pred`` (``t_next`` being the following
    scheduled timestep, or 0.0 after the last one). Token 0 must come out of
    the step unchanged.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    n_batch, n_tok = latents.shape[:2]
    grid = (n_batch, n_tok)
    schedule = list(scheduler.timesteps)
    for current_t, following_t in zip(schedule, schedule[1:] + [0.0]):
        per_token_t = torch.full(grid, current_t)
        per_token_t[:, 0] = 0.0
        model_out = torch.randn_like(latents)
        stepped = scheduler.step(model_out, per_token_t, latents, return_dict=False)[0]

        step_sizes = per_token_t - torch.full(grid, following_t)
        expected = latents - step_sizes.unsqueeze(-1) * model_out
        # Token 0's entry in `expected` is not meaningful (its step size works
        # out to -following_t), so token 0 is compared to the input instead.
        assert torch.allclose(stepped[:, 1:], expected[:, 1:], atol=1e-06)
        assert torch.allclose(stepped[:, 0], latents[:, 0], atol=1e-06)
@pytest.mark.parametrize("sampler", ["LinearQuadratic", "Uniform"])
def test_scheduler_t_not_in_list(sampler):
    """
    Test the case of a timestep per token NOT from the list of timesteps.

    For each scheduled timestep t_i, every token is assigned the midpoint
    between t_i and its successor t_{i+1}. The successor of the last timestep
    is 0.0, so the midpoint there is t_i / 2. Such a timestep is off the
    schedule, and the step is expected to equal
    ``latents - (t_mid - t_{i+1}) * noise_pred``.
    """
    scheduler, latents = init_latents_and_scheduler(sampler)
    batch_size, n_tokens = latents.shape[:2]
    schedule = list(scheduler.timesteps)
    # Pair every scheduled timestep with its successor; the last one is
    # followed by 0.0. This replaces index arithmetic and duplicated
    # "is this the last step?" checks.
    successors = schedule[1:] + [0.0]
    for current_t, next_t in zip(schedule, successors):
        # For the last step next_t == 0.0, so this is exactly current_t / 2.
        t = (current_t + next_t) / 2
        timesteps = torch.full((batch_size, n_tokens), t)
        noise_pred = torch.randn_like(latents)
        denoised_latents = scheduler.step(
            noise_pred,
            timesteps,
            latents,
            return_dict=False,
        )[0]

        # Verify the denoising against the explicit update rule.
        next_timesteps = torch.full((batch_size, n_tokens), next_t)
        dt = timesteps - next_timesteps
        expected_denoised_latents = latents - dt.unsqueeze(-1) * noise_pred
        assert torch.allclose(denoised_latents, expected_denoised_latents, atol=1e-06)