| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import unittest |
| |
|
| | import torch |
| |
|
| | from diffusers import LuminaNextDiT2DModel |
| |
|
| | from ...testing_utils import ( |
| | enable_full_determinism, |
| | torch_device, |
| | ) |
| | from ..test_modeling_common import ModelTesterMixin |
| |
|
| |
|
# Runs once at import time, before any test case in this module is collected.
# NOTE(review): presumably switches torch into a fully deterministic mode so the
# randomly generated fixtures below are reproducible -- confirm in testing_utils.
enable_full_determinism()
| |
|
| |
|
class LuminaNextDiT2DModelTransformerTests(ModelTesterMixin, unittest.TestCase):
    """Wire ``LuminaNextDiT2DModel`` into the ``ModelTesterMixin`` test case.

    The class only supplies fixtures: the model class under test, the name of
    its main input, a small random input batch and a tiny model configuration.
    """

    model_class = LuminaNextDiT2DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        """Build a fresh batch of random keyword inputs on ``torch_device``.

        Returns:
            Dict: keyword arguments keyed by ``hidden_states``,
            ``encoder_hidden_states``, ``timestep``, ``encoder_mask``,
            ``image_rotary_emb`` and ``cross_attention_kwargs``.
        """
        batch, channels = 2, 4
        height, width = 16, 16
        text_len, text_dim = 16, 32

        # The draws below must stay in this order: each call consumes the
        # global RNG stream, so reordering would change the generated values.
        latents = torch.randn(batch, channels, height, width).to(torch_device)
        text_states = torch.randn(batch, text_len, text_dim).to(torch_device)
        steps = torch.rand(batch).to(torch_device)
        # NOTE(review): the mask is sampled from a normal distribution rather
        # than being filled with 0/1 values -- confirm this is intentional.
        text_mask = torch.randn(batch, text_len).to(torch_device)
        rotary = torch.randn(384, 384, 4).to(torch_device)

        inputs = {}
        inputs["hidden_states"] = latents
        inputs["encoder_hidden_states"] = text_states
        inputs["timestep"] = steps
        inputs["encoder_mask"] = text_mask
        inputs["image_rotary_emb"] = rotary
        inputs["cross_attention_kwargs"] = {}
        return inputs

    @property
    def input_shape(self):
        """Per-sample input shape (no batch dimension).

        Returns:
            Tuple: (int, int, int)
        """
        channels, height, width = 4, 16, 16
        return (channels, height, width)

    @property
    def output_shape(self):
        """Per-sample output shape (no batch dimension).

        Returns:
            Tuple: (int, int, int)
        """
        channels, height, width = 4, 16, 16
        return (channels, height, width)

    def prepare_init_args_and_inputs_for_common(self):
        """Pair a small model configuration with a matching dummy input batch.

        Returns:
            Tuple: (Dict, Dict) -- constructor keyword arguments for
            ``model_class`` followed by the value of ``dummy_input``.
        """
        model_kwargs = dict(
            sample_size=16,
            patch_size=2,
            in_channels=4,
            hidden_size=24,
            num_layers=2,
            num_attention_heads=3,
            num_kv_heads=1,
            multiple_of=16,
            ffn_dim_multiplier=None,
            norm_eps=1e-5,
            learn_sigma=False,
            qk_norm=True,
            cross_attention_dim=32,
            scaling_factor=1.0,
        )
        return model_kwargs, self.dummy_input
| |
|