Upload pipeline_stable_diffusion_migc.py
Browse files- pipeline_stable_diffusion_migc.py +110 -72
pipeline_stable_diffusion_migc.py
CHANGED
|
@@ -22,11 +22,9 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
|
|
| 22 |
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
| 23 |
from diffusers.utils.torch_utils import randn_tensor
|
| 24 |
from packaging import version
|
| 25 |
-
from scipy.ndimage import uniform_filter
|
| 26 |
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 27 |
|
| 28 |
-
|
| 29 |
-
from core.diffusion.migc.mich_arch import MIGC, NaiveFuser
|
| 30 |
|
| 31 |
logger = logging.get_logger(__name__)
|
| 32 |
|
|
@@ -51,7 +49,8 @@ class MIGCProcessor(AttnProcessor):
|
|
| 51 |
hidden_states: torch.Tensor,
|
| 52 |
encoder_hidden_states: torch.Tensor | None = None,
|
| 53 |
attention_mask: torch.Tensor | None = None,
|
| 54 |
-
|
|
|
|
| 55 |
bboxes: List[List[float]] = [],
|
| 56 |
ith: int = 0,
|
| 57 |
embeds_pooler: torch.Tensor | None = None,
|
|
@@ -62,14 +61,39 @@ class MIGCProcessor(AttnProcessor):
|
|
| 62 |
ca_scale: float | None = None,
|
| 63 |
ea_scale: float | None = None,
|
| 64 |
sac_scale: float | None = None,
|
|
|
|
|
|
|
|
|
|
| 65 |
):
|
| 66 |
batch_size, sequence_length, _ = hidden_states.shape
|
| 67 |
assert batch_size == 1 or batch_size == 2, (
|
| 68 |
"We currently only implement sampling with batch_size=1, and we will implement sampling with batch_size=N as soon as possible."
|
| 69 |
)
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
instance_num = len(bboxes)
|
| 74 |
|
| 75 |
if ith > MIGCsteps:
|
|
@@ -80,43 +104,67 @@ class MIGCProcessor(AttnProcessor):
|
|
| 80 |
|
| 81 |
is_cross = encoder_hidden_states is not None
|
| 82 |
|
| 83 |
-
# ori_hidden_states = hidden_states.clone()
|
| 84 |
-
|
| 85 |
# In this case, we need to use MIGC or naive_fuser, so
|
| 86 |
# 1. We concat prompt embeds and phrases embeds
|
| 87 |
# 2. we copy the hidden_states_cond (instance_num+1) times for QKV
|
| 88 |
if is_cross and not is_vanilla_cross:
|
|
|
|
| 89 |
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_phrases])
|
| 90 |
# print(encoder_hidden_states.shape)
|
| 91 |
hidden_states_uncond = hidden_states[[0], ...]
|
| 92 |
hidden_states_cond = hidden_states[[1], ...].repeat(instance_num + 1, 1, 1)
|
| 93 |
hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond])
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
# QKV Operation of Vanilla Self-Attention or Cross-Attention
|
| 96 |
query = attn.to_q(hidden_states)
|
| 97 |
|
| 98 |
if encoder_hidden_states is None:
|
| 99 |
encoder_hidden_states = hidden_states
|
|
|
|
|
|
|
| 100 |
|
| 101 |
key = attn.to_k(encoder_hidden_states)
|
| 102 |
value = attn.to_v(encoder_hidden_states)
|
| 103 |
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
|
| 108 |
hidden_states = F.scaled_dot_product_attention(
|
| 109 |
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 110 |
)
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# linear proj
|
| 116 |
hidden_states = attn.to_out[0](hidden_states)
|
| 117 |
# dropout
|
| 118 |
hidden_states = attn.to_out[1](hidden_states)
|
| 119 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
###### Self-Attention Results ######
|
| 121 |
if not is_cross:
|
| 122 |
return hidden_states
|
|
@@ -129,42 +177,12 @@ class MIGCProcessor(AttnProcessor):
|
|
| 129 |
# hidden_states: torch.Size([1+1+instance_num, HW, C]), the first 1 is the uncond ca output, the second 1 is the global ca output.
|
| 130 |
hidden_states_uncond = hidden_states[[0], ...] # torch.Size([1, HW, C])
|
| 131 |
cond_ca_output = hidden_states[1:, ...].unsqueeze(0) # torch.Size([1, 1+instance_num, 5, 64, 1280])
|
| 132 |
-
guidance_masks = []
|
| 133 |
-
in_box = []
|
| 134 |
-
# Construct Instance Guidance Mask
|
| 135 |
-
for bbox in bboxes:
|
| 136 |
-
guidance_mask = np.zeros((height, width))
|
| 137 |
-
w_min = int(width * bbox[0])
|
| 138 |
-
w_max = int(width * bbox[2])
|
| 139 |
-
h_min = int(height * bbox[1])
|
| 140 |
-
h_max = int(height * bbox[3])
|
| 141 |
-
guidance_mask[h_min:h_max, w_min:w_max] = 1.0
|
| 142 |
-
guidance_masks.append(guidance_mask[None, ...])
|
| 143 |
-
in_box.append([bbox[0], bbox[2], bbox[1], bbox[3]])
|
| 144 |
-
|
| 145 |
-
# Construct Background Guidance Mask
|
| 146 |
-
sup_mask = get_sup_mask(guidance_masks)
|
| 147 |
-
supplement_mask = torch.from_numpy(sup_mask[None, ...])
|
| 148 |
-
supplement_mask = F.interpolate(supplement_mask, (height // 8, width // 8), mode="bilinear").float()
|
| 149 |
-
supplement_mask = supplement_mask.to(hidden_states.device) # (1, 1, H, W)
|
| 150 |
-
|
| 151 |
-
guidance_masks = np.concatenate(guidance_masks, axis=0)
|
| 152 |
-
guidance_masks = guidance_masks[None, ...]
|
| 153 |
-
guidance_masks = torch.from_numpy(guidance_masks).float().to(cond_ca_output.device)
|
| 154 |
-
guidance_masks = F.interpolate(
|
| 155 |
-
guidance_masks, (height // 8, width // 8), mode="bilinear"
|
| 156 |
-
) # (1, instance_num, H, W)
|
| 157 |
-
|
| 158 |
-
in_box = torch.from_numpy(np.array(in_box))[None, ...].float().to(cond_ca_output.device) # (1, instance_num, 4)
|
| 159 |
|
| 160 |
other_info = {}
|
| 161 |
other_info["image_token"] = hidden_states_cond[None, ...]
|
| 162 |
-
other_info["context"] = encoder_hidden_states[1:, ...]
|
| 163 |
other_info["box"] = in_box
|
| 164 |
other_info["context_pooler"] = embeds_pooler[:, None, :] # (instance_num, 1, 768)
|
| 165 |
other_info["supplement_mask"] = supplement_mask
|
| 166 |
-
other_info["attn2"] = None
|
| 167 |
-
other_info["attn"] = attn
|
| 168 |
other_info["height"] = height
|
| 169 |
other_info["width"] = width
|
| 170 |
other_info["ca_scale"] = ca_scale
|
|
@@ -326,7 +344,7 @@ class StableDiffusionMIGCPipeline(
|
|
| 326 |
scheduler: KarrasDiffusionSchedulers,
|
| 327 |
safety_checker: StableDiffusionSafetyChecker,
|
| 328 |
feature_extractor: CLIPImageProcessor,
|
| 329 |
-
image_encoder: CLIPVisionModelWithProjection = None,
|
| 330 |
requires_safety_checker: bool = True,
|
| 331 |
):
|
| 332 |
super().__init__()
|
|
@@ -419,7 +437,11 @@ class StableDiffusionMIGCPipeline(
|
|
| 419 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 420 |
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 421 |
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
def _register_migc_adapters(self, unet: UNet2DConditionModel):
|
| 425 |
for name, module in unet.named_modules():
|
|
@@ -448,7 +470,7 @@ class StableDiffusionMIGCPipeline(
|
|
| 448 |
device,
|
| 449 |
num_images_per_prompt,
|
| 450 |
do_classifier_free_guidance,
|
| 451 |
-
negative_prompt=None,
|
| 452 |
prompt_embeds: Optional[torch.Tensor] = None,
|
| 453 |
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 454 |
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
|
@@ -881,24 +903,19 @@ class StableDiffusionMIGCPipeline(
|
|
| 881 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 882 |
MIGCsteps=20,
|
| 883 |
NaiveFuserSteps=-1,
|
| 884 |
-
ca_scale=None,
|
| 885 |
-
ea_scale=None,
|
| 886 |
-
sac_scale=None,
|
| 887 |
-
aug_phase_with_and=False,
|
| 888 |
-
sa_preserve=False,
|
| 889 |
-
use_sa_preserve=False,
|
| 890 |
**kwargs,
|
| 891 |
):
|
| 892 |
r"""
|
| 893 |
The call function to the pipeline for generation.
|
| 894 |
|
| 895 |
Args:
|
| 896 |
-
prompt (`str
|
| 897 |
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 898 |
instead.
|
| 899 |
-
|
| 900 |
-
The list of the indexes in the prompt to layout. Defaults to None.
|
| 901 |
-
bboxes (Union[List[List[List[float]]], List[List[float]]], optional):
|
| 902 |
The bounding boxes of the indexes to maintain layout in the image. Defaults to None.
|
| 903 |
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 904 |
The height in pixels of the generated image.
|
|
@@ -1037,7 +1054,7 @@ class StableDiffusionMIGCPipeline(
|
|
| 1037 |
else:
|
| 1038 |
batch_size = prompt_embeds.shape[0]
|
| 1039 |
if batch_size > 1:
|
| 1040 |
-
raise NotImplementedError("Batch processing is not supported.")
|
| 1041 |
|
| 1042 |
device = self._execution_device
|
| 1043 |
|
|
@@ -1067,6 +1084,7 @@ class StableDiffusionMIGCPipeline(
|
|
| 1067 |
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 1068 |
# to avoid doing two forward passes
|
| 1069 |
if self.do_classifier_free_guidance:
|
|
|
|
| 1070 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 1071 |
|
| 1072 |
# 4. Prepare timesteps
|
|
@@ -1087,6 +1105,37 @@ class StableDiffusionMIGCPipeline(
|
|
| 1087 |
latents,
|
| 1088 |
)
|
| 1089 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1090 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1091 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1092 |
|
|
@@ -1098,20 +1147,6 @@ class StableDiffusionMIGCPipeline(
|
|
| 1098 |
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 1099 |
).to(device=device, dtype=latents.dtype)
|
| 1100 |
|
| 1101 |
-
# 6.2 prepare MIGC guidance_mask
|
| 1102 |
-
guidance_mask = np.full((4, height // 8, width // 8), 1.0)
|
| 1103 |
-
|
| 1104 |
-
for bbox in bboxes:
|
| 1105 |
-
w_min = max(0, int(width * bbox[0] // 8) - 5)
|
| 1106 |
-
w_max = min(width, int(width * bbox[2] // 8) + 5)
|
| 1107 |
-
h_min = max(0, int(height * bbox[1] // 8) - 5)
|
| 1108 |
-
h_max = min(height, int(height * bbox[3] // 8) + 5)
|
| 1109 |
-
guidance_mask[:, h_min:h_max, w_min:w_max] = 0
|
| 1110 |
-
|
| 1111 |
-
kernal_size = 5
|
| 1112 |
-
guidance_mask = uniform_filter(guidance_mask, axes=(1, 2), size=kernal_size)
|
| 1113 |
-
guidance_mask = torch.from_numpy(guidance_mask).to(self.device).unsqueeze(0)
|
| 1114 |
-
|
| 1115 |
# 7. Denoising loop
|
| 1116 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 1117 |
self._num_timesteps = len(timesteps)
|
|
@@ -1139,6 +1174,9 @@ class StableDiffusionMIGCPipeline(
|
|
| 1139 |
"ca_scale": ca_scale,
|
| 1140 |
"ea_scale": ea_scale,
|
| 1141 |
"sac_scale": sac_scale,
|
|
|
|
|
|
|
|
|
|
| 1142 |
}
|
| 1143 |
|
| 1144 |
noise_pred = self.unet(
|
|
|
|
| 22 |
from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
| 23 |
from diffusers.utils.torch_utils import randn_tensor
|
| 24 |
from packaging import version
|
|
|
|
| 25 |
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
| 26 |
|
| 27 |
+
from core.diffusion.migc.migc_archs import MIGC, NaiveFuser
|
|
|
|
| 28 |
|
| 29 |
logger = logging.get_logger(__name__)
|
| 30 |
|
|
|
|
| 49 |
hidden_states: torch.Tensor,
|
| 50 |
encoder_hidden_states: torch.Tensor | None = None,
|
| 51 |
attention_mask: torch.Tensor | None = None,
|
| 52 |
+
temb: torch.Tensor | None = None,
|
| 53 |
+
encoder_hidden_states_phrases: torch.Tensor | None = None,
|
| 54 |
bboxes: List[List[float]] = [],
|
| 55 |
ith: int = 0,
|
| 56 |
embeds_pooler: torch.Tensor | None = None,
|
|
|
|
| 61 |
ca_scale: float | None = None,
|
| 62 |
ea_scale: float | None = None,
|
| 63 |
sac_scale: float | None = None,
|
| 64 |
+
guidance_masks: torch.Tensor | None = None,
|
| 65 |
+
supplement_mask: torch.Tensor | None = None,
|
| 66 |
+
in_box: torch.Tensor | None = None,
|
| 67 |
):
|
| 68 |
batch_size, sequence_length, _ = hidden_states.shape
|
| 69 |
assert batch_size == 1 or batch_size == 2, (
|
| 70 |
"We currently only implement sampling with batch_size=1, and we will implement sampling with batch_size=N as soon as possible."
|
| 71 |
)
|
| 72 |
+
residual = hidden_states
|
| 73 |
+
if attn.spatial_norm is not None:
|
| 74 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
| 75 |
|
| 76 |
+
input_ndim = hidden_states.ndim
|
| 77 |
|
| 78 |
+
if input_ndim == 4:
|
| 79 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 80 |
+
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
| 81 |
+
|
| 82 |
+
batch_size, sequence_length, _ = (
|
| 83 |
+
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
if attention_mask is not None:
|
| 87 |
+
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
| 88 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 89 |
+
# (batch, heads, source_length, target_length)
|
| 90 |
+
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
| 91 |
+
|
| 92 |
+
if attn.group_norm is not None:
|
| 93 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
| 94 |
+
|
| 95 |
+
##########
|
| 96 |
+
# Expand encoder_hidden_states with encoder_hidden_states_phrases
|
| 97 |
instance_num = len(bboxes)
|
| 98 |
|
| 99 |
if ith > MIGCsteps:
|
|
|
|
| 104 |
|
| 105 |
is_cross = encoder_hidden_states is not None
|
| 106 |
|
|
|
|
|
|
|
| 107 |
# In this case, we need to use MIGC or naive_fuser, so
|
| 108 |
# 1. We concat prompt embeds and phrases embeds
|
| 109 |
# 2. we copy the hidden_states_cond (instance_num+1) times for QKV
|
| 110 |
if is_cross and not is_vanilla_cross:
|
| 111 |
+
batch_size_phrases = encoder_hidden_states_phrases.shape[0]
|
| 112 |
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_phrases])
|
| 113 |
# print(encoder_hidden_states.shape)
|
| 114 |
hidden_states_uncond = hidden_states[[0], ...]
|
| 115 |
hidden_states_cond = hidden_states[[1], ...].repeat(instance_num + 1, 1, 1)
|
| 116 |
hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond])
|
| 117 |
+
else:
|
| 118 |
+
batch_size_phrases = 0
|
| 119 |
+
##########
|
| 120 |
|
|
|
|
| 121 |
query = attn.to_q(hidden_states)
|
| 122 |
|
| 123 |
if encoder_hidden_states is None:
|
| 124 |
encoder_hidden_states = hidden_states
|
| 125 |
+
elif attn.norm_cross:
|
| 126 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
| 127 |
|
| 128 |
key = attn.to_k(encoder_hidden_states)
|
| 129 |
value = attn.to_v(encoder_hidden_states)
|
| 130 |
|
| 131 |
+
inner_dim = key.shape[-1]
|
| 132 |
+
head_dim = inner_dim // attn.heads
|
| 133 |
+
|
| 134 |
+
query = query.view(batch_size + batch_size_phrases, -1, attn.heads, head_dim).transpose(1, 2)
|
| 135 |
+
|
| 136 |
+
key = key.view(batch_size + batch_size_phrases, -1, attn.heads, head_dim).transpose(1, 2)
|
| 137 |
+
value = value.view(batch_size + batch_size_phrases, -1, attn.heads, head_dim).transpose(1, 2)
|
| 138 |
+
|
| 139 |
+
if attn.norm_q is not None:
|
| 140 |
+
query = attn.norm_q(query)
|
| 141 |
+
if attn.norm_k is not None:
|
| 142 |
+
key = attn.norm_k(key)
|
| 143 |
|
| 144 |
hidden_states = F.scaled_dot_product_attention(
|
| 145 |
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
| 146 |
)
|
| 147 |
+
|
| 148 |
+
hidden_states = hidden_states.transpose(1, 2).reshape(
|
| 149 |
+
batch_size + batch_size_phrases, -1, attn.heads * head_dim
|
| 150 |
+
)
|
| 151 |
+
hidden_states = hidden_states.to(query.dtype)
|
| 152 |
|
| 153 |
# linear proj
|
| 154 |
hidden_states = attn.to_out[0](hidden_states)
|
| 155 |
# dropout
|
| 156 |
hidden_states = attn.to_out[1](hidden_states)
|
| 157 |
|
| 158 |
+
if input_ndim == 4:
|
| 159 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 160 |
+
batch_size + batch_size_phrases, channel, height, width
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
if attn.residual_connection:
|
| 164 |
+
hidden_states = hidden_states + residual
|
| 165 |
+
|
| 166 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 167 |
+
|
| 168 |
###### Self-Attention Results ######
|
| 169 |
if not is_cross:
|
| 170 |
return hidden_states
|
|
|
|
| 177 |
# hidden_states: torch.Size([1+1+instance_num, HW, C]), the first 1 is the uncond ca output, the second 1 is the global ca output.
|
| 178 |
hidden_states_uncond = hidden_states[[0], ...] # torch.Size([1, HW, C])
|
| 179 |
cond_ca_output = hidden_states[1:, ...].unsqueeze(0) # torch.Size([1, 1+instance_num, 5, 64, 1280])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 180 |
|
| 181 |
other_info = {}
|
| 182 |
other_info["image_token"] = hidden_states_cond[None, ...]
|
|
|
|
| 183 |
other_info["box"] = in_box
|
| 184 |
other_info["context_pooler"] = embeds_pooler[:, None, :] # (instance_num, 1, 768)
|
| 185 |
other_info["supplement_mask"] = supplement_mask
|
|
|
|
|
|
|
| 186 |
other_info["height"] = height
|
| 187 |
other_info["width"] = width
|
| 188 |
other_info["ca_scale"] = ca_scale
|
|
|
|
| 344 |
scheduler: KarrasDiffusionSchedulers,
|
| 345 |
safety_checker: StableDiffusionSafetyChecker,
|
| 346 |
feature_extractor: CLIPImageProcessor,
|
| 347 |
+
image_encoder: CLIPVisionModelWithProjection | None = None,
|
| 348 |
requires_safety_checker: bool = True,
|
| 349 |
):
|
| 350 |
super().__init__()
|
|
|
|
| 437 |
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 438 |
self.register_to_config(requires_safety_checker=requires_safety_checker)
|
| 439 |
|
| 440 |
+
self.default_sample_size = (
|
| 441 |
+
self.unet.config.sample_size
|
| 442 |
+
if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
|
| 443 |
+
else 64
|
| 444 |
+
)
|
| 445 |
|
| 446 |
def _register_migc_adapters(self, unet: UNet2DConditionModel):
|
| 447 |
for name, module in unet.named_modules():
|
|
|
|
| 470 |
device,
|
| 471 |
num_images_per_prompt,
|
| 472 |
do_classifier_free_guidance,
|
| 473 |
+
negative_prompt: str | List[str] | None = None,
|
| 474 |
prompt_embeds: Optional[torch.Tensor] = None,
|
| 475 |
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 476 |
pooled_prompt_embeds: Optional[torch.Tensor] = None,
|
|
|
|
| 903 |
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 904 |
MIGCsteps=20,
|
| 905 |
NaiveFuserSteps=-1,
|
| 906 |
+
ca_scale: float | None = None,
|
| 907 |
+
ea_scale: float | None = None,
|
| 908 |
+
sac_scale: float | None = None,
|
|
|
|
|
|
|
|
|
|
| 909 |
**kwargs,
|
| 910 |
):
|
| 911 |
r"""
|
| 912 |
The call function to the pipeline for generation.
|
| 913 |
|
| 914 |
Args:
|
| 915 |
+
prompt (`str`, *optional*):
|
| 916 |
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
| 917 |
instead.
|
| 918 |
+
bboxes (List[List[float]]], optional):
|
|
|
|
|
|
|
| 919 |
The bounding boxes of the indexes to maintain layout in the image. Defaults to None.
|
| 920 |
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 921 |
The height in pixels of the generated image.
|
|
|
|
| 1054 |
else:
|
| 1055 |
batch_size = prompt_embeds.shape[0]
|
| 1056 |
if batch_size > 1:
|
| 1057 |
+
raise NotImplementedError("Batch processing is not supported yet.")
|
| 1058 |
|
| 1059 |
device = self._execution_device
|
| 1060 |
|
|
|
|
| 1084 |
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 1085 |
# to avoid doing two forward passes
|
| 1086 |
if self.do_classifier_free_guidance:
|
| 1087 |
+
assert isinstance(negative_prompt_embeds, torch.Tensor)
|
| 1088 |
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 1089 |
|
| 1090 |
# 4. Prepare timesteps
|
|
|
|
| 1105 |
latents,
|
| 1106 |
)
|
| 1107 |
|
| 1108 |
+
# 5.1 Prepare guidance masks
|
| 1109 |
+
guidance_masks = []
|
| 1110 |
+
in_box = []
|
| 1111 |
+
# Construct Instance Guidance Mask
|
| 1112 |
+
for bbox in bboxes:
|
| 1113 |
+
guidance_mask = np.zeros((height, width))
|
| 1114 |
+
w_min = int(width * bbox[0])
|
| 1115 |
+
w_max = int(width * bbox[2])
|
| 1116 |
+
h_min = int(height * bbox[1])
|
| 1117 |
+
h_max = int(height * bbox[3])
|
| 1118 |
+
guidance_mask[h_min:h_max, w_min:w_max] = 1.0
|
| 1119 |
+
guidance_masks.append(guidance_mask[None, ...])
|
| 1120 |
+
in_box.append([bbox[0], bbox[2], bbox[1], bbox[3]])
|
| 1121 |
+
|
| 1122 |
+
# Construct Background Guidance Mask
|
| 1123 |
+
sup_mask = get_sup_mask(guidance_masks)
|
| 1124 |
+
supplement_mask = torch.from_numpy(sup_mask[None, ...])
|
| 1125 |
+
supplement_mask = F.interpolate(supplement_mask, (height // 8, width // 8), mode="bilinear")
|
| 1126 |
+
supplement_mask = supplement_mask.to(device=device, dtype=self.unet.dtype) # (1, 1, H, W)
|
| 1127 |
+
|
| 1128 |
+
guidance_masks = np.concatenate(guidance_masks, axis=0)
|
| 1129 |
+
guidance_masks = guidance_masks[None, ...]
|
| 1130 |
+
guidance_masks = torch.from_numpy(guidance_masks).to(device=device, dtype=self.unet.dtype)
|
| 1131 |
+
guidance_masks = F.interpolate(
|
| 1132 |
+
guidance_masks, (height // 8, width // 8), mode="bilinear"
|
| 1133 |
+
) # (1, instance_num, H, W)
|
| 1134 |
+
|
| 1135 |
+
in_box = torch.from_numpy(np.array(in_box))[None, ...].to(
|
| 1136 |
+
device=device, dtype=self.unet.dtype
|
| 1137 |
+
) # (1, instance_num, 4)
|
| 1138 |
+
|
| 1139 |
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 1140 |
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
| 1141 |
|
|
|
|
| 1147 |
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
|
| 1148 |
).to(device=device, dtype=latents.dtype)
|
| 1149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1150 |
# 7. Denoising loop
|
| 1151 |
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 1152 |
self._num_timesteps = len(timesteps)
|
|
|
|
| 1174 |
"ca_scale": ca_scale,
|
| 1175 |
"ea_scale": ea_scale,
|
| 1176 |
"sac_scale": sac_scale,
|
| 1177 |
+
"guidance_masks": guidance_masks,
|
| 1178 |
+
"supplement_mask": supplement_mask,
|
| 1179 |
+
"in_box": in_box,
|
| 1180 |
}
|
| 1181 |
|
| 1182 |
noise_pred = self.unet(
|