# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import LTX2VideoTransformer3DModel

from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
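

# Force deterministic torch kernels (and related CUDA settings) so the tensor
# comparisons in the shared model tests are reproducible across runs.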
enable_full_determinism()


class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = LTX2VideoTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        # Common
        batch_size = 2
        # Video
        num_frames = 2
        num_channels = 4
        height = 16
        width = 16
        # Audio
        audio_num_frames = 9
        audio_num_channels = 2
        num_mel_bins = 2
        # Text
        embedding_dim = 16
        sequence_length = 16

        hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
        audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
            torch_device
        )
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
        timestep = torch.rand((batch_size,)).to(torch_device) * 1000

        return {
            "hidden_states": hidden_states,
            "audio_hidden_states": audio_hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "audio_encoder_hidden_states": audio_encoder_hidden_states,
            "timestep": timestep,
            "encoder_attention_mask": encoder_attention_mask,
            "num_frames": num_frames,
            "height": height,
            "width": width,
            "audio_num_frames": audio_num_frames,
            "fps": 25.0,
        }
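
    # Shape notes: the transformer consumes pre-patchified token sequences, i.e. video
    # latents flattened to (batch, num_frames * height * width, in_channels) and audio
    # latents flattened to (batch, audio_num_frames, audio_num_channels * num_mel_bins).
    # num_frames/height/width/fps are passed alongside, presumably so the model can rebuild
    # the spatio-temporal coordinate grid for RoPE from the flat sequence.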

    @property
    def input_shape(self):
        # 512 tokens = num_frames * height * width from `dummy_input`; 4 latent channels
        return (512, 4)

    @property
    def output_shape(self):
        return (512, 4)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 4,
            "out_channels": 4,
            "patch_size": 1,
            "patch_size_t": 1,
            "num_attention_heads": 2,
            "attention_head_dim": 8,
            "cross_attention_dim": 16,
            "audio_in_channels": 4,
            "audio_out_channels": 4,
            "audio_num_attention_heads": 2,
            "audio_attention_head_dim": 4,
            "audio_cross_attention_dim": 8,
            "num_layers": 2,
            "qk_norm": "rms_norm_across_heads",
            "caption_channels": 16,
            "rope_double_precision": False,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict
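
    # Note: the config above is deliberately tiny (2 layers, 2 attention heads per stream)
    # so the shared ModelTesterMixin tests run quickly; it is not representative of any
    # released LTX-2 configuration.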

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"LTX2VideoTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
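
    # The numerical consistency test below is kept for reference but commented out: it
    # depends on the internal "diffusers-internal-dev/dummy-ltx2" checkpoint and on the
    # `attention_backend` test helper, neither of which is wired up here.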
    # def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
    #     torch.manual_seed(seed)
    #     init_dict, _ = self.prepare_init_args_and_inputs_for_common()
    #
    #     # Calculate dummy inputs in a custom manner to ensure compatibility with original code
    #     batch_size = 2
    #     num_frames = 9
    #     latent_frames = 2
    #     text_embedding_dim = 16
    #     text_seq_len = 16
    #     fps = 25.0
    #     sampling_rate = 16000.0
    #     hop_length = 160.0
    #
    #     sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
    #     timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
    #
    #     num_channels = 4
    #     latent_height = 4
    #     latent_width = 4
    #     hidden_states = torch.randn(
    #         (batch_size, num_channels, latent_frames, latent_height, latent_width),
    #         generator=torch.manual_seed(seed),
    #         dtype=dtype,
    #         device="cpu",
    #     )
    #     # Patchify video latents (with patch_size (1, 1, 1))
    #     hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
    #     hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    #
    #     encoder_hidden_states = torch.randn(
    #         (batch_size, text_seq_len, text_embedding_dim),
    #         generator=torch.manual_seed(seed),
    #         dtype=dtype,
    #         device="cpu",
    #     )
    #
    #     audio_num_channels = 2
    #     num_mel_bins = 2
    #     latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
    #     audio_hidden_states = torch.randn(
    #         (batch_size, audio_num_channels, latent_length, num_mel_bins),
    #         generator=torch.manual_seed(seed),
    #         dtype=dtype,
    #         device="cpu",
    #     )
    #     # Patchify audio latents
    #     audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
    #
    #     audio_encoder_hidden_states = torch.randn(
    #         (batch_size, text_seq_len, text_embedding_dim),
    #         generator=torch.manual_seed(seed),
    #         dtype=dtype,
    #         device="cpu",
    #     )
    #
    #     inputs_dict = {
    #         "hidden_states": hidden_states.to(device=torch_device),
    #         "audio_hidden_states": audio_hidden_states.to(device=torch_device),
    #         "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
    #         "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
    #         "timestep": timestep,
    #         "num_frames": latent_frames,
    #         "height": latent_height,
    #         "width": latent_width,
    #         "audio_num_frames": num_frames,
    #         "fps": 25.0,
    #     }
    #
    #     model = self.model_class.from_pretrained(
    #         "diffusers-internal-dev/dummy-ltx2",
    #         subfolder="transformer",
    #         device_map="cpu",
    #     )
    #     # torch.manual_seed(seed)
    #     # model = self.model_class(**init_dict)
    #     model.to(torch_device)
    #     model.eval()
    #
    #     with attention_backend("native"):
    #         with torch.no_grad():
    #             output = model(**inputs_dict)
    #     video_output, audio_output = output.to_tuple()
    #
    #     self.assertIsNotNone(video_output)
    #     self.assertIsNotNone(audio_output)
    #
    #     # input & output have to have the same shape
    #     video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
    #     self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
    #     audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
    #     self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
    #
    #     # Check against expected slice
    #     # fmt: off
    #     video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
    #     audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
    #     # fmt: on
    #
    #     video_output_flat = video_output.cpu().flatten().float()
    #     video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
    #     self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
    #
    #     audio_output_flat = audio_output.cpu().flatten().float()
    #     audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
    #     self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))


class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = LTX2VideoTransformer3DModel

    def prepare_init_args_and_inputs_for_common(self):
        return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()
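

# A minimal sketch of how to run this module locally, assuming the usual diffusers test
# layout (the exact path is an assumption, not confirmed by this file):
#   pytest tests/models/transformers/test_models_transformer_ltx2.py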