import random
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKL, ChromaImg2ImgPipeline, ChromaTransformer2DModel, FlowMatchEulerDiscreteScheduler

from ...testing_utils import floats_tensor, torch_device
from ..test_pipelines_common import FluxIPAdapterTesterMixin, PipelineTesterMixin, check_qkv_fused_layers_exist


class ChromaImg2ImgPipelineFastTests(
    unittest.TestCase,
    PipelineTesterMixin,
    FluxIPAdapterTesterMixin,
):
    pipeline_class = ChromaImg2ImgPipeline
    params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds"])
    batch_params = frozenset(["prompt"])

    # there is no xformers processor for Flux
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        torch.manual_seed(0)
        transformer = ChromaTransformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=32,
            axes_dims_rope=[4, 4, 8],
            approximator_hidden_dim=32,
            approximator_layers=1,
            approximator_num_channels=16,
        )

        torch.manual_seed(0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        vae = AutoencoderKL(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
            shift_factor=0.0609,
            scaling_factor=1.5035,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
            "image_encoder": None,
            "feature_extractor": None,
        }

    def get_dummy_inputs(self, device, seed=0):
        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)

        if str(device).startswith("mps"):
            # mps does not reliably support device-specific torch.Generator objects,
            # so seed globally instead.
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "image": image,
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 5.0,
            "height": 8,
            "width": 8,
            "max_sequence_length": 48,
            "strength": 0.8,
            "output_type": "np",
        }
        return inputs

    def test_chroma_different_prompts(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)

        inputs = self.get_dummy_inputs(torch_device)
        output_same_prompt = pipe(**inputs).images[0]

        inputs = self.get_dummy_inputs(torch_device)
        inputs["prompt"] = "a different prompt"
        output_different_prompts = pipe(**inputs).images[0]

        max_diff = np.abs(output_same_prompt - output_different_prompts).max()

        # Outputs should be different here. For some reason, they don't show
        # large differences.
        assert max_diff > 1e-6

    def test_fused_qkv_projections(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        # TODO (sayakpaul): will refactor this once `fuse_qkv_projections()` has been added
        # to the pipeline level.
        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        assert np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3), (
            "Fusion of QKV projections shouldn't affect the outputs."
        )
        assert np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3), (
            "Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."
        )
        assert np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2), (
            "Original outputs should match when fused QKV projections are disabled."
        )

    def test_chroma_image_output_shape(self):
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            # The pipeline snaps the requested resolution down to the nearest
            # multiple of `vae_scale_factor * 2`.
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            assert (output_height, output_width) == (expected_height, expected_width)