import gc
import os
import unittest

import torch

from diffusers import ZImageTransformer2DModel

from ...testing_utils import IS_GITHUB_ACTIONS, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


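# Pin CUDA behavior for reproducible test outputs: synchronous launches, a fixed
# cuBLAS workspace, deterministic non-benchmarked cuDNN kernels, and no TF32
# matmuls. `use_deterministic_algorithms(False)` keeps ops that lack deterministic
# implementations from raising, instead of enforcing strict global determinism.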
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if hasattr(torch.backends, "cuda"):
    torch.backends.cuda.matmul.allow_tf32 = False


@unittest.skipIf(
    IS_GITHUB_ACTIONS,
    reason="Skipping the test suite on CI because the model calls `torch.empty()` during init and we don't have a clear way to override it in the modeling tests.",
)
class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = ZImageTransformer2DModel
    main_input_name = "x"

    model_split_percents = [0.9, 0.9, 0.9]

    def prepare_dummy_input(self, height=16, width=16):
        batch_size = 1
        num_channels = 16
        embedding_dim = 16
        sequence_length = 16

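        # The model consumes per-sample lists (one tensor per batch element) for
        # latents and caption features rather than batched tensors.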
        hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)]
        encoder_hidden_states = [
            torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size)
        ]
        timestep = torch.tensor([0.0]).to(torch_device)

        return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}

    @property
    def dummy_input(self):
        return self.prepare_dummy_input()

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
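        # Minimal config for speed. `axes_dims` sums to 16 to match the attention
        # head dim implied by dim // n_heads (an assumption based on how axis-wise
        # RoPE dims are typically tied to head dim in this family of models).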
        init_dict = {
            "all_patch_size": (2,),
            "all_f_patch_size": (1,),
            "in_channels": 16,
            "dim": 16,
            "n_layers": 1,
            "n_refiner_layers": 1,
            "n_heads": 1,
            "n_kv_heads": 1,
            "qk_norm": True,
            "cap_feat_dim": 16,
            "rope_theta": 256.0,
            "t_scale": 1000.0,
            "axes_dims": [8, 4, 4],
            "axes_lens": [256, 32, 32],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def setUp(self):
        super().setUp()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"ZImageTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Test does not support main inputs that are lists.")
    def test_training(self):
        super().test_training()

    @unittest.skip("Test does not support main inputs that are lists.")
    def test_ema_training(self):
        super().test_ema_training()

    @unittest.skip("Test does not support main inputs that are lists.")
    def test_effective_gradient_checkpointing(self):
        super().test_effective_gradient_checkpointing()

    @unittest.skip(
        "Test needs to be revisited: `x_pad_token` and `cap_pad_token` must be cast to the same dtype as the "
        "destination tensor before they are assigned to the padding indices."
    )
    def test_layerwise_casting_training(self):
        super().test_layerwise_casting_training()
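
    # A hypothetical sketch of the fix described in the skip reason above (attribute
    # and variable names are assumed, not taken from the actual modeling code):
    #   padded_x[pad_indices] = self.x_pad_token.to(dtype=padded_x.dtype)
    #   padded_caps[cap_pad_indices] = self.cap_pad_token.to(dtype=padded_caps.dtype)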

    @unittest.skip("Test does not support main inputs that are lists.")
    def test_outputs_equivalence(self):
        super().test_outputs_equivalence()

    @unittest.skip("Test would pass if the DiT used deterministic initial values instead of `torch.empty()`.")
    def test_group_offloading(self):
        super().test_group_offloading()

    @unittest.skip("Test would pass if the DiT used deterministic initial values instead of `torch.empty()`.")
    def test_group_offloading_with_disk(self):
        super().test_group_offloading_with_disk()


class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = ZImageTransformer2DModel
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

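    # Reuse the tiny config and inputs from ZImageTransformerTests. Instantiating the
    # TestCase directly is safe here: unittest allows construction with the default
    # methodName ("runTest") even when no such method exists.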
    def prepare_init_args_and_inputs_for_common(self):
        return ZImageTransformerTests().prepare_init_args_and_inputs_for_common()

    def prepare_dummy_input(self, height, width):
        return ZImageTransformerTests().prepare_dummy_input(height=height, width=width)

    @unittest.skip(
        "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, "
        "context_refiner, and layers. As a consequence, the inputs recorded for the block vary during "
        "compilation, and full compilation with fullgraph=True would trigger recompilation at least three times."
    )
    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()

    @unittest.skip("Fullgraph AoT is broken")
    def test_compile_works_with_aot(self):
        super().test_compile_works_with_aot()

    @unittest.skip("Fullgraph is broken")
    def test_compile_on_different_shapes(self):
        super().test_compile_on_different_shapes()