|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
import torch |
|
|
|
|
|
from diffusers import SD3Transformer2DModel |
|
|
from diffusers.utils.import_utils import is_xformers_available |
|
|
from diffusers.utils.testing_utils import ( |
|
|
enable_full_determinism, |
|
|
torch_device, |
|
|
) |
|
|
|
|
|
from ..test_modeling_common import ModelTesterMixin |
|
|
|
|
|
|
|
|
enable_full_determinism() |
|
|
|
|
|
|
|
|
class SD3TransformerTests(ModelTesterMixin, unittest.TestCase): |
|
|
model_class = SD3Transformer2DModel |
|
|
main_input_name = "hidden_states" |
|
|
model_split_percents = [0.8, 0.8, 0.9] |
|
|
|
|
|
@property |
|
|
def dummy_input(self): |
|
|
batch_size = 2 |
|
|
num_channels = 4 |
|
|
height = width = embedding_dim = 32 |
|
|
pooled_embedding_dim = embedding_dim * 2 |
|
|
sequence_length = 154 |
|
|
|
|
|
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) |
|
|
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) |
|
|
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device) |
|
|
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) |
|
|
|
|
|
return { |
|
|
"hidden_states": hidden_states, |
|
|
"encoder_hidden_states": encoder_hidden_states, |
|
|
"pooled_projections": pooled_prompt_embeds, |
|
|
"timestep": timestep, |
|
|
} |
|
|
|
|
|
@property |
|
|
def input_shape(self): |
|
|
return (4, 32, 32) |
|
|
|
|
|
@property |
|
|
def output_shape(self): |
|
|
return (4, 32, 32) |
|
|
|
|
|
def prepare_init_args_and_inputs_for_common(self): |
|
|
init_dict = { |
|
|
"sample_size": 32, |
|
|
"patch_size": 1, |
|
|
"in_channels": 4, |
|
|
"num_layers": 4, |
|
|
"attention_head_dim": 8, |
|
|
"num_attention_heads": 4, |
|
|
"caption_projection_dim": 32, |
|
|
"joint_attention_dim": 32, |
|
|
"pooled_projection_dim": 64, |
|
|
"out_channels": 4, |
|
|
"pos_embed_max_size": 96, |
|
|
"dual_attention_layers": (), |
|
|
"qk_norm": None, |
|
|
} |
|
|
inputs_dict = self.dummy_input |
|
|
return init_dict, inputs_dict |
|
|
|
|
|
@unittest.skipIf( |
|
|
torch_device != "cuda" or not is_xformers_available(), |
|
|
reason="XFormers attention is only available with CUDA and `xformers` installed", |
|
|
) |
|
|
def test_xformers_enable_works(self): |
|
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
|
|
model = self.model_class(**init_dict) |
|
|
|
|
|
model.enable_xformers_memory_efficient_attention() |
|
|
|
|
|
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", ( |
|
|
"xformers is not enabled" |
|
|
) |
|
|
|
|
|
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") |
|
|
def test_set_attn_processor_for_determinism(self): |
|
|
pass |
|
|
|
|
|
def test_gradient_checkpointing_is_applied(self): |
|
|
expected_set = {"SD3Transformer2DModel"} |
|
|
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
|
|
|
|
|
|
|
|
class SD35TransformerTests(ModelTesterMixin, unittest.TestCase): |
|
|
model_class = SD3Transformer2DModel |
|
|
main_input_name = "hidden_states" |
|
|
model_split_percents = [0.8, 0.8, 0.9] |
|
|
|
|
|
@property |
|
|
def dummy_input(self): |
|
|
batch_size = 2 |
|
|
num_channels = 4 |
|
|
height = width = embedding_dim = 32 |
|
|
pooled_embedding_dim = embedding_dim * 2 |
|
|
sequence_length = 154 |
|
|
|
|
|
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device) |
|
|
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device) |
|
|
pooled_prompt_embeds = torch.randn((batch_size, pooled_embedding_dim)).to(torch_device) |
|
|
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device) |
|
|
|
|
|
return { |
|
|
"hidden_states": hidden_states, |
|
|
"encoder_hidden_states": encoder_hidden_states, |
|
|
"pooled_projections": pooled_prompt_embeds, |
|
|
"timestep": timestep, |
|
|
} |
|
|
|
|
|
@property |
|
|
def input_shape(self): |
|
|
return (4, 32, 32) |
|
|
|
|
|
@property |
|
|
def output_shape(self): |
|
|
return (4, 32, 32) |
|
|
|
|
|
def prepare_init_args_and_inputs_for_common(self): |
|
|
init_dict = { |
|
|
"sample_size": 32, |
|
|
"patch_size": 1, |
|
|
"in_channels": 4, |
|
|
"num_layers": 4, |
|
|
"attention_head_dim": 8, |
|
|
"num_attention_heads": 4, |
|
|
"caption_projection_dim": 32, |
|
|
"joint_attention_dim": 32, |
|
|
"pooled_projection_dim": 64, |
|
|
"out_channels": 4, |
|
|
"pos_embed_max_size": 96, |
|
|
"dual_attention_layers": (0,), |
|
|
"qk_norm": "rms_norm", |
|
|
} |
|
|
inputs_dict = self.dummy_input |
|
|
return init_dict, inputs_dict |
|
|
|
|
|
@unittest.skipIf( |
|
|
torch_device != "cuda" or not is_xformers_available(), |
|
|
reason="XFormers attention is only available with CUDA and `xformers` installed", |
|
|
) |
|
|
def test_xformers_enable_works(self): |
|
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
|
|
model = self.model_class(**init_dict) |
|
|
|
|
|
model.enable_xformers_memory_efficient_attention() |
|
|
|
|
|
assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor", ( |
|
|
"xformers is not enabled" |
|
|
) |
|
|
|
|
|
@unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply") |
|
|
def test_set_attn_processor_for_determinism(self): |
|
|
pass |
|
|
|
|
|
def test_gradient_checkpointing_is_applied(self): |
|
|
expected_set = {"SD3Transformer2DModel"} |
|
|
super().test_gradient_checkpointing_is_applied(expected_set=expected_set) |
|
|
|
|
|
def test_skip_layers(self): |
|
|
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() |
|
|
model = self.model_class(**init_dict).to(torch_device) |
|
|
|
|
|
|
|
|
output_full = model(**inputs_dict).sample |
|
|
|
|
|
|
|
|
inputs_dict_with_skip = inputs_dict.copy() |
|
|
inputs_dict_with_skip["skip_layers"] = [0] |
|
|
output_skip = model(**inputs_dict_with_skip).sample |
|
|
|
|
|
|
|
|
self.assertFalse( |
|
|
torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped" |
|
|
) |
|
|
|
|
|
|
|
|
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape") |
|
|
|