File size: 6,075 Bytes
44e6efe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import unittest
import numpy as np
import PIL.Image
import torch
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import (
AutoencoderKL,
FasterCacheConfig,
FlowMatchEulerDiscreteScheduler,
FluxKontextPipeline,
FluxTransformer2DModel,
)
from ...testing_utils import torch_device
from ..test_pipelines_common import (
FasterCacheTesterMixin,
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
PyramidAttentionBroadcastTesterMixin,
)
class FluxKontextPipelineFastTests(
unittest.TestCase,
PipelineTesterMixin,
FluxIPAdapterTesterMixin,
PyramidAttentionBroadcastTesterMixin,
FasterCacheTesterMixin,
):
pipeline_class = FluxKontextPipeline
params = frozenset(
["image", "prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]
)
batch_params = frozenset(["image", "prompt"])
# there is no xformers processor for Flux
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
faster_cache_config = FasterCacheConfig(
spatial_attention_block_skip_range=2,
spatial_attention_timestep_skip_range=(-1, 901),
unconditional_batch_skip_range=2,
attention_weight_callback=lambda _: 0.5,
is_guidance_distilled=True,
)
def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
torch.manual_seed(0)
transformer = FluxTransformer2DModel(
patch_size=1,
in_channels=4,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=16,
num_attention_heads=2,
joint_attention_dim=32,
pooled_projection_dim=32,
axes_dims_rope=[4, 4, 8],
)
clip_text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
hidden_act="gelu",
projection_dim=32,
)
torch.manual_seed(0)
text_encoder = CLIPTextModel(clip_text_encoder_config)
torch.manual_seed(0)
text_encoder_2 = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
tokenizer_2 = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=1,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0609,
scaling_factor=1.5035,
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {
"scheduler": scheduler,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
image = PIL.Image.new("RGB", (32, 32), 0)
inputs = {
"image": image,
"prompt": "A painting of a squirrel eating a burger",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 5.0,
"height": 8,
"width": 8,
"max_area": 8 * 8,
"max_sequence_length": 48,
"output_type": "np",
"_auto_resize": False,
}
return inputs
def test_flux_different_prompts(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
output_same_prompt = pipe(**inputs).images[0]
inputs = self.get_dummy_inputs(torch_device)
inputs["prompt_2"] = "a different prompt"
output_different_prompts = pipe(**inputs).images[0]
max_diff = np.abs(output_same_prompt - output_different_prompts).max()
# Outputs should be different here
# For some reasons, they don't show large differences
assert max_diff > 1e-6
def test_flux_image_output_shape(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
height_width_pairs = [(32, 32), (72, 57)]
for height, width in height_width_pairs:
expected_height = height - height % (pipe.vae_scale_factor * 2)
expected_width = width - width % (pipe.vae_scale_factor * 2)
inputs.update({"height": height, "width": width, "max_area": height * width})
image = pipe(**inputs).images[0]
output_height, output_width, _ = image.shape
assert (output_height, output_width) == (expected_height, expected_width)
def test_flux_true_cfg(self):
pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
inputs = self.get_dummy_inputs(torch_device)
inputs.pop("generator")
no_true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
inputs["negative_prompt"] = "bad quality"
inputs["true_cfg_scale"] = 2.0
true_cfg_out = pipe(**inputs, generator=torch.manual_seed(0)).images[0]
assert not np.allclose(no_true_cfg_out, true_cfg_out)
|