File size: 6,507 Bytes
ac2243f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import os
import unittest

import torch

from diffusers import ZImageTransformer2DModel

from ...testing_utils import IS_GITHUB_ACTIONS, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin


# Z-Image needs torch.use_deterministic_algorithms(False) because of its complex64 RoPE ops,
# so enable_full_determinism() (which turns deterministic algorithms on) cannot be used.
# The individual reproducibility knobs are configured by hand instead, at import time.
os.environ.update(
    {
        "CUDA_LAUNCH_BLOCKING": "1",
        "CUBLAS_WORKSPACE_CONFIG": ":16:8",
    }
)
torch.use_deterministic_algorithms(False)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Guarded attribute access, kept from the original in case `torch.backends.cuda` is absent.
if hasattr(torch.backends, "cuda"):
    torch.backends.cuda.matmul.allow_tf32 = False


@unittest.skipIf(
    IS_GITHUB_ACTIONS,
    reason="Skipping test-suite inside the CI because the model has `torch.empty()` inside of it during init and we don't have a clear way to override it in the modeling tests.",
)
class ZImageTransformerTests(ModelTesterMixin, unittest.TestCase):
    """Common modeling tests for `ZImageTransformer2DModel` on a tiny configuration.

    The model's main input `x` and its caption features `cap_feats` are Python
    lists of per-sample tensors, so several mixin tests that assume a single
    batched tensor are skipped below.
    """

    model_class = ZImageTransformer2DModel
    main_input_name = "x"
    # We override the items here because the transformer under consideration is small.
    model_split_percents = [0.9, 0.9, 0.9]

    def prepare_dummy_input(self, height=16, width=16):
        """Build random forward-pass inputs for the tiny model.

        Args:
            height: Spatial height of each latent in `x`.
            width: Spatial width of each latent in `x`.

        Returns:
            A dict with `"x"` (list of `(16, 1, height, width)` tensors),
            `"cap_feats"` (list of `(16, 16)` tensors) and `"t"` (a one-element
            timestep tensor), all moved to `torch_device`.
        """
        batch_size = 1
        num_channels = 16
        embedding_dim = 16
        sequence_length = 16

        # Sample on the default (CPU) generator and then move the result, so the
        # values depend only on the CPU seed, whatever `torch_device` is.
        hidden_states = [torch.randn((num_channels, 1, height, width)).to(torch_device) for _ in range(batch_size)]
        encoder_hidden_states = [
            torch.randn((sequence_length, embedding_dim)).to(torch_device) for _ in range(batch_size)
        ]
        timestep = torch.tensor([0.0]).to(torch_device)

        return {"x": hidden_states, "cap_feats": encoder_hidden_states, "t": timestep}

    @property
    def dummy_input(self):
        """Default-sized (16x16) dummy inputs; a fresh random draw on every access."""
        return self.prepare_dummy_input()

    # NOTE(review): these shapes do not match the dummy input built above
    # ((16, 1, 16, 16) per sample); confirm which mixin tests still read them.
    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        """Return `(init_dict, inputs_dict)` for a one-layer, one-head model with dim 16."""
        init_dict = {
            "all_patch_size": (2,),
            "all_f_patch_size": (1,),
            "in_channels": 16,
            "dim": 16,
            "n_layers": 1,
            "n_refiner_layers": 1,
            "n_heads": 1,
            # NOTE(review): n_kv_heads (2) exceeds n_heads (1); presumably unused by
            # the model for this config - confirm against the modeling code.
            "n_kv_heads": 2,
            "qk_norm": True,
            "cap_feat_dim": 16,
            "rope_theta": 256.0,
            "t_scale": 1000.0,
            # axes_dims sum to 16 == dim / n_heads; presumably required to match the per-head dim.
            "axes_dims": [8, 4, 4],
            "axes_lens": [256, 32, 32],
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @staticmethod
    def _reset_torch_state():
        """Run GC, free cached CUDA memory (if any) and reseed the RNGs with 0."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        torch.manual_seed(0)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(0)

    def setUp(self):
        # Call the base/mixin setUp, mirroring tearDown's super().tearDown().
        super().setUp()
        self._reset_torch_state()

    def tearDown(self):
        super().tearDown()
        self._reset_torch_state()

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"ZImageTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @unittest.skip("Test is not supported for handling main inputs that are lists.")
    def test_training(self):
        super().test_training()

    @unittest.skip("Test is not supported for handling main inputs that are lists.")
    def test_ema_training(self):
        super().test_ema_training()

    @unittest.skip("Test is not supported for handling main inputs that are lists.")
    def test_effective_gradient_checkpointing(self):
        super().test_effective_gradient_checkpointing()

    @unittest.skip(
        "Test needs to be revisited. But we need to ensure `x_pad_token` and `cap_pad_token` are cast to the same dtype as the destination tensor before they are assigned to the padding indices."
    )
    def test_layerwise_casting_training(self):
        super().test_layerwise_casting_training()

    @unittest.skip("Test is not supported for handling main inputs that are lists.")
    def test_outputs_equivalence(self):
        super().test_outputs_equivalence()

    @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
    def test_group_offloading(self):
        super().test_group_offloading()

    @unittest.skip("Test will pass if we change to deterministic values instead of empty in the DiT.")
    def test_group_offloading_with_disk(self):
        super().test_group_offloading_with_disk()


class ZImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    """`torch.compile` tests for `ZImageTransformer2DModel`.

    Model config and inputs are taken from `ZImageTransformerTests`, so both
    suites exercise the same tiny model.
    """

    model_class = ZImageTransformer2DModel
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    @staticmethod
    def _zimage_reference_suite():
        """Return a fresh `ZImageTransformerTests` instance to borrow fixtures from."""
        return ZImageTransformerTests()

    def prepare_init_args_and_inputs_for_common(self):
        reference = self._zimage_reference_suite()
        return reference.prepare_init_args_and_inputs_for_common()

    def prepare_dummy_input(self, height, width):
        reference = self._zimage_reference_suite()
        return reference.prepare_dummy_input(height=height, width=width)

    @unittest.skip(
        "The repeated block in this model is ZImageTransformerBlock, which is used for noise_refiner, context_refiner, and layers. As a consequence of this, the inputs recorded for the block would vary during compilation and full compilation with fullgraph=True would trigger recompilation at least thrice."
    )
    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()

    @unittest.skip("Fullgraph AoT is broken")
    def test_compile_works_with_aot(self):
        super().test_compile_works_with_aot()

    @unittest.skip("Fullgraph is broken")
    def test_compile_on_different_shapes(self):
        super().test_compile_on_different_shapes()