thisiswooyeol committed on
Commit
5d03d1e
·
verified ·
1 Parent(s): 5feb9d1

Upload pipeline_stable_diffusion_migc.py

Browse files
Files changed (1) hide show
  1. pipeline_stable_diffusion_migc.py +110 -72
pipeline_stable_diffusion_migc.py CHANGED
@@ -22,11 +22,9 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
22
  from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
23
  from diffusers.utils.torch_utils import randn_tensor
24
  from packaging import version
25
- from scipy.ndimage import uniform_filter
26
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
27
 
28
- # from utils import load_utils
29
- from core.diffusion.migc.mich_arch import MIGC, NaiveFuser
30
 
31
  logger = logging.get_logger(__name__)
32
 
@@ -51,7 +49,8 @@ class MIGCProcessor(AttnProcessor):
51
  hidden_states: torch.Tensor,
52
  encoder_hidden_states: torch.Tensor | None = None,
53
  attention_mask: torch.Tensor | None = None,
54
- encoder_hidden_states_phrases=None,
 
55
  bboxes: List[List[float]] = [],
56
  ith: int = 0,
57
  embeds_pooler: torch.Tensor | None = None,
@@ -62,14 +61,39 @@ class MIGCProcessor(AttnProcessor):
62
  ca_scale: float | None = None,
63
  ea_scale: float | None = None,
64
  sac_scale: float | None = None,
 
 
 
65
  ):
66
  batch_size, sequence_length, _ = hidden_states.shape
67
  assert batch_size == 1 or batch_size == 2, (
68
  "We currently only implement sampling with batch_size=1, and we will implement sampling with batch_size=N as soon as possible."
69
  )
 
 
 
70
 
71
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  instance_num = len(bboxes)
74
 
75
  if ith > MIGCsteps:
@@ -80,43 +104,67 @@ class MIGCProcessor(AttnProcessor):
80
 
81
  is_cross = encoder_hidden_states is not None
82
 
83
- # ori_hidden_states = hidden_states.clone()
84
-
85
  # In this case, we need to use MIGC or naive_fuser, so
86
  # 1. We concat prompt embeds and phrases embeds
87
  # 2. we copy the hidden_states_cond (instance_num+1) times for QKV
88
  if is_cross and not is_vanilla_cross:
 
89
  encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_phrases])
90
  # print(encoder_hidden_states.shape)
91
  hidden_states_uncond = hidden_states[[0], ...]
92
  hidden_states_cond = hidden_states[[1], ...].repeat(instance_num + 1, 1, 1)
93
  hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond])
 
 
 
94
 
95
- # QKV Operation of Vanilla Self-Attention or Cross-Attention
96
  query = attn.to_q(hidden_states)
97
 
98
  if encoder_hidden_states is None:
99
  encoder_hidden_states = hidden_states
 
 
100
 
101
  key = attn.to_k(encoder_hidden_states)
102
  value = attn.to_v(encoder_hidden_states)
103
 
104
- query = attn.head_to_batch_dim(query)
105
- key = attn.head_to_batch_dim(key)
106
- value = attn.head_to_batch_dim(value)
 
 
 
 
 
 
 
 
 
107
 
108
  hidden_states = F.scaled_dot_product_attention(
109
  query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
110
  )
111
- # attention_probs = attn.get_attention_scores(query, key, attention_mask) # 48 4096 77
112
- # hidden_states = torch.bmm(attention_probs, value)
113
- hidden_states = attn.batch_to_head_dim(hidden_states)
 
 
114
 
115
  # linear proj
116
  hidden_states = attn.to_out[0](hidden_states)
117
  # dropout
118
  hidden_states = attn.to_out[1](hidden_states)
119
 
 
 
 
 
 
 
 
 
 
 
120
  ###### Self-Attention Results ######
121
  if not is_cross:
122
  return hidden_states
@@ -129,42 +177,12 @@ class MIGCProcessor(AttnProcessor):
129
  # hidden_states: torch.Size([1+1+instance_num, HW, C]), the first 1 is the uncond ca output, the second 1 is the global ca output.
130
  hidden_states_uncond = hidden_states[[0], ...] # torch.Size([1, HW, C])
131
  cond_ca_output = hidden_states[1:, ...].unsqueeze(0) # torch.Size([1, 1+instance_num, 5, 64, 1280])
132
- guidance_masks = []
133
- in_box = []
134
- # Construct Instance Guidance Mask
135
- for bbox in bboxes:
136
- guidance_mask = np.zeros((height, width))
137
- w_min = int(width * bbox[0])
138
- w_max = int(width * bbox[2])
139
- h_min = int(height * bbox[1])
140
- h_max = int(height * bbox[3])
141
- guidance_mask[h_min:h_max, w_min:w_max] = 1.0
142
- guidance_masks.append(guidance_mask[None, ...])
143
- in_box.append([bbox[0], bbox[2], bbox[1], bbox[3]])
144
-
145
- # Construct Background Guidance Mask
146
- sup_mask = get_sup_mask(guidance_masks)
147
- supplement_mask = torch.from_numpy(sup_mask[None, ...])
148
- supplement_mask = F.interpolate(supplement_mask, (height // 8, width // 8), mode="bilinear").float()
149
- supplement_mask = supplement_mask.to(hidden_states.device) # (1, 1, H, W)
150
-
151
- guidance_masks = np.concatenate(guidance_masks, axis=0)
152
- guidance_masks = guidance_masks[None, ...]
153
- guidance_masks = torch.from_numpy(guidance_masks).float().to(cond_ca_output.device)
154
- guidance_masks = F.interpolate(
155
- guidance_masks, (height // 8, width // 8), mode="bilinear"
156
- ) # (1, instance_num, H, W)
157
-
158
- in_box = torch.from_numpy(np.array(in_box))[None, ...].float().to(cond_ca_output.device) # (1, instance_num, 4)
159
 
160
  other_info = {}
161
  other_info["image_token"] = hidden_states_cond[None, ...]
162
- other_info["context"] = encoder_hidden_states[1:, ...]
163
  other_info["box"] = in_box
164
  other_info["context_pooler"] = embeds_pooler[:, None, :] # (instance_num, 1, 768)
165
  other_info["supplement_mask"] = supplement_mask
166
- other_info["attn2"] = None
167
- other_info["attn"] = attn
168
  other_info["height"] = height
169
  other_info["width"] = width
170
  other_info["ca_scale"] = ca_scale
@@ -326,7 +344,7 @@ class StableDiffusionMIGCPipeline(
326
  scheduler: KarrasDiffusionSchedulers,
327
  safety_checker: StableDiffusionSafetyChecker,
328
  feature_extractor: CLIPImageProcessor,
329
- image_encoder: CLIPVisionModelWithProjection = None,
330
  requires_safety_checker: bool = True,
331
  ):
332
  super().__init__()
@@ -419,7 +437,11 @@ class StableDiffusionMIGCPipeline(
419
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
420
  self.register_to_config(requires_safety_checker=requires_safety_checker)
421
 
422
- # self.embedding = {}
 
 
 
 
423
 
424
  def _register_migc_adapters(self, unet: UNet2DConditionModel):
425
  for name, module in unet.named_modules():
@@ -448,7 +470,7 @@ class StableDiffusionMIGCPipeline(
448
  device,
449
  num_images_per_prompt,
450
  do_classifier_free_guidance,
451
- negative_prompt=None,
452
  prompt_embeds: Optional[torch.Tensor] = None,
453
  negative_prompt_embeds: Optional[torch.Tensor] = None,
454
  pooled_prompt_embeds: Optional[torch.Tensor] = None,
@@ -881,24 +903,19 @@ class StableDiffusionMIGCPipeline(
881
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
882
  MIGCsteps=20,
883
  NaiveFuserSteps=-1,
884
- ca_scale=None,
885
- ea_scale=None,
886
- sac_scale=None,
887
- aug_phase_with_and=False,
888
- sa_preserve=False,
889
- use_sa_preserve=False,
890
  **kwargs,
891
  ):
892
  r"""
893
  The call function to the pipeline for generation.
894
 
895
  Args:
896
- prompt (`str` or `List[str]`, *optional*):
897
  The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
898
  instead.
899
- token_indices (Union[List[List[List[int]]], List[List[int]]], optional):
900
- The list of the indexes in the prompt to layout. Defaults to None.
901
- bboxes (Union[List[List[List[float]]], List[List[float]]], optional):
902
  The bounding boxes of the indexes to maintain layout in the image. Defaults to None.
903
  height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
904
  The height in pixels of the generated image.
@@ -1037,7 +1054,7 @@ class StableDiffusionMIGCPipeline(
1037
  else:
1038
  batch_size = prompt_embeds.shape[0]
1039
  if batch_size > 1:
1040
- raise NotImplementedError("Batch processing is not supported.")
1041
 
1042
  device = self._execution_device
1043
 
@@ -1067,6 +1084,7 @@ class StableDiffusionMIGCPipeline(
1067
  # Here we concatenate the unconditional and text embeddings into a single batch
1068
  # to avoid doing two forward passes
1069
  if self.do_classifier_free_guidance:
 
1070
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1071
 
1072
  # 4. Prepare timesteps
@@ -1087,6 +1105,37 @@ class StableDiffusionMIGCPipeline(
1087
  latents,
1088
  )
1089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1090
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1091
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1092
 
@@ -1098,20 +1147,6 @@ class StableDiffusionMIGCPipeline(
1098
  guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1099
  ).to(device=device, dtype=latents.dtype)
1100
 
1101
- # 6.2 prepare MIGC guidance_mask
1102
- guidance_mask = np.full((4, height // 8, width // 8), 1.0)
1103
-
1104
- for bbox in bboxes:
1105
- w_min = max(0, int(width * bbox[0] // 8) - 5)
1106
- w_max = min(width, int(width * bbox[2] // 8) + 5)
1107
- h_min = max(0, int(height * bbox[1] // 8) - 5)
1108
- h_max = min(height, int(height * bbox[3] // 8) + 5)
1109
- guidance_mask[:, h_min:h_max, w_min:w_max] = 0
1110
-
1111
- kernal_size = 5
1112
- guidance_mask = uniform_filter(guidance_mask, axes=(1, 2), size=kernal_size)
1113
- guidance_mask = torch.from_numpy(guidance_mask).to(self.device).unsqueeze(0)
1114
-
1115
  # 7. Denoising loop
1116
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1117
  self._num_timesteps = len(timesteps)
@@ -1139,6 +1174,9 @@ class StableDiffusionMIGCPipeline(
1139
  "ca_scale": ca_scale,
1140
  "ea_scale": ea_scale,
1141
  "sac_scale": sac_scale,
 
 
 
1142
  }
1143
 
1144
  noise_pred = self.unet(
 
22
  from diffusers.utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
23
  from diffusers.utils.torch_utils import randn_tensor
24
  from packaging import version
 
25
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
26
 
27
+ from core.diffusion.migc.migc_archs import MIGC, NaiveFuser
 
28
 
29
  logger = logging.get_logger(__name__)
30
 
 
49
  hidden_states: torch.Tensor,
50
  encoder_hidden_states: torch.Tensor | None = None,
51
  attention_mask: torch.Tensor | None = None,
52
+ temb: torch.Tensor | None = None,
53
+ encoder_hidden_states_phrases: torch.Tensor | None = None,
54
  bboxes: List[List[float]] = [],
55
  ith: int = 0,
56
  embeds_pooler: torch.Tensor | None = None,
 
61
  ca_scale: float | None = None,
62
  ea_scale: float | None = None,
63
  sac_scale: float | None = None,
64
+ guidance_masks: torch.Tensor | None = None,
65
+ supplement_mask: torch.Tensor | None = None,
66
+ in_box: torch.Tensor | None = None,
67
  ):
68
  batch_size, sequence_length, _ = hidden_states.shape
69
  assert batch_size == 1 or batch_size == 2, (
70
  "We currently only implement sampling with batch_size=1, and we will implement sampling with batch_size=N as soon as possible."
71
  )
72
+ residual = hidden_states
73
+ if attn.spatial_norm is not None:
74
+ hidden_states = attn.spatial_norm(hidden_states, temb)
75
 
76
+ input_ndim = hidden_states.ndim
77
 
78
+ if input_ndim == 4:
79
+ batch_size, channel, height, width = hidden_states.shape
80
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
81
+
82
+ batch_size, sequence_length, _ = (
83
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
84
+ )
85
+
86
+ if attention_mask is not None:
87
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
88
+ # scaled_dot_product_attention expects attention_mask shape to be
89
+ # (batch, heads, source_length, target_length)
90
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
91
+
92
+ if attn.group_norm is not None:
93
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
94
+
95
+ ##########
96
+ # Expand encoder_hidden_states with encoder_hidden_states_phrases
97
  instance_num = len(bboxes)
98
 
99
  if ith > MIGCsteps:
 
104
 
105
  is_cross = encoder_hidden_states is not None
106
 
 
 
107
  # In this case, we need to use MIGC or naive_fuser, so
108
  # 1. We concat prompt embeds and phrases embeds
109
  # 2. we copy the hidden_states_cond (instance_num+1) times for QKV
110
  if is_cross and not is_vanilla_cross:
111
+ batch_size_phrases = encoder_hidden_states_phrases.shape[0]
112
  encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_phrases])
113
  # print(encoder_hidden_states.shape)
114
  hidden_states_uncond = hidden_states[[0], ...]
115
  hidden_states_cond = hidden_states[[1], ...].repeat(instance_num + 1, 1, 1)
116
  hidden_states = torch.cat([hidden_states_uncond, hidden_states_cond])
117
+ else:
118
+ batch_size_phrases = 0
119
+ ##########
120
 
 
121
  query = attn.to_q(hidden_states)
122
 
123
  if encoder_hidden_states is None:
124
  encoder_hidden_states = hidden_states
125
+ elif attn.norm_cross:
126
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
127
 
128
  key = attn.to_k(encoder_hidden_states)
129
  value = attn.to_v(encoder_hidden_states)
130
 
131
+ inner_dim = key.shape[-1]
132
+ head_dim = inner_dim // attn.heads
133
+
134
+ query = query.view(batch_size + batch_size_phrases, -1, attn.heads, head_dim).transpose(1, 2)
135
+
136
+ key = key.view(batch_size + batch_size_phrases, -1, attn.heads, head_dim).transpose(1, 2)
137
+ value = value.view(batch_size + batch_size_phrases, -1, attn.heads, head_dim).transpose(1, 2)
138
+
139
+ if attn.norm_q is not None:
140
+ query = attn.norm_q(query)
141
+ if attn.norm_k is not None:
142
+ key = attn.norm_k(key)
143
 
144
  hidden_states = F.scaled_dot_product_attention(
145
  query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
146
  )
147
+
148
+ hidden_states = hidden_states.transpose(1, 2).reshape(
149
+ batch_size + batch_size_phrases, -1, attn.heads * head_dim
150
+ )
151
+ hidden_states = hidden_states.to(query.dtype)
152
 
153
  # linear proj
154
  hidden_states = attn.to_out[0](hidden_states)
155
  # dropout
156
  hidden_states = attn.to_out[1](hidden_states)
157
 
158
+ if input_ndim == 4:
159
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
160
+ batch_size + batch_size_phrases, channel, height, width
161
+ )
162
+
163
+ if attn.residual_connection:
164
+ hidden_states = hidden_states + residual
165
+
166
+ hidden_states = hidden_states / attn.rescale_output_factor
167
+
168
  ###### Self-Attention Results ######
169
  if not is_cross:
170
  return hidden_states
 
177
  # hidden_states: torch.Size([1+1+instance_num, HW, C]), the first 1 is the uncond ca output, the second 1 is the global ca output.
178
  hidden_states_uncond = hidden_states[[0], ...] # torch.Size([1, HW, C])
179
  cond_ca_output = hidden_states[1:, ...].unsqueeze(0) # torch.Size([1, 1+instance_num, 5, 64, 1280])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  other_info = {}
182
  other_info["image_token"] = hidden_states_cond[None, ...]
 
183
  other_info["box"] = in_box
184
  other_info["context_pooler"] = embeds_pooler[:, None, :] # (instance_num, 1, 768)
185
  other_info["supplement_mask"] = supplement_mask
 
 
186
  other_info["height"] = height
187
  other_info["width"] = width
188
  other_info["ca_scale"] = ca_scale
 
344
  scheduler: KarrasDiffusionSchedulers,
345
  safety_checker: StableDiffusionSafetyChecker,
346
  feature_extractor: CLIPImageProcessor,
347
+ image_encoder: CLIPVisionModelWithProjection | None = None,
348
  requires_safety_checker: bool = True,
349
  ):
350
  super().__init__()
 
437
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
438
  self.register_to_config(requires_safety_checker=requires_safety_checker)
439
 
440
+ self.default_sample_size = (
441
+ self.unet.config.sample_size
442
+ if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
443
+ else 64
444
+ )
445
 
446
  def _register_migc_adapters(self, unet: UNet2DConditionModel):
447
  for name, module in unet.named_modules():
 
470
  device,
471
  num_images_per_prompt,
472
  do_classifier_free_guidance,
473
+ negative_prompt: str | List[str] | None = None,
474
  prompt_embeds: Optional[torch.Tensor] = None,
475
  negative_prompt_embeds: Optional[torch.Tensor] = None,
476
  pooled_prompt_embeds: Optional[torch.Tensor] = None,
 
903
  callback_on_step_end_tensor_inputs: List[str] = ["latents"],
904
  MIGCsteps=20,
905
  NaiveFuserSteps=-1,
906
+ ca_scale: float | None = None,
907
+ ea_scale: float | None = None,
908
+ sac_scale: float | None = None,
 
 
 
909
  **kwargs,
910
  ):
911
  r"""
912
  The call function to the pipeline for generation.
913
 
914
  Args:
915
+ prompt (`str`, *optional*):
916
  The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
917
  instead.
918
+ bboxes (List[List[float]]], optional):
 
 
919
  The bounding boxes of the indexes to maintain layout in the image. Defaults to None.
920
  height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
921
  The height in pixels of the generated image.
 
1054
  else:
1055
  batch_size = prompt_embeds.shape[0]
1056
  if batch_size > 1:
1057
+ raise NotImplementedError("Batch processing is not supported yet.")
1058
 
1059
  device = self._execution_device
1060
 
 
1084
  # Here we concatenate the unconditional and text embeddings into a single batch
1085
  # to avoid doing two forward passes
1086
  if self.do_classifier_free_guidance:
1087
+ assert isinstance(negative_prompt_embeds, torch.Tensor)
1088
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
1089
 
1090
  # 4. Prepare timesteps
 
1105
  latents,
1106
  )
1107
 
1108
+ # 5.1 Prepare guidance masks
1109
+ guidance_masks = []
1110
+ in_box = []
1111
+ # Construct Instance Guidance Mask
1112
+ for bbox in bboxes:
1113
+ guidance_mask = np.zeros((height, width))
1114
+ w_min = int(width * bbox[0])
1115
+ w_max = int(width * bbox[2])
1116
+ h_min = int(height * bbox[1])
1117
+ h_max = int(height * bbox[3])
1118
+ guidance_mask[h_min:h_max, w_min:w_max] = 1.0
1119
+ guidance_masks.append(guidance_mask[None, ...])
1120
+ in_box.append([bbox[0], bbox[2], bbox[1], bbox[3]])
1121
+
1122
+ # Construct Background Guidance Mask
1123
+ sup_mask = get_sup_mask(guidance_masks)
1124
+ supplement_mask = torch.from_numpy(sup_mask[None, ...])
1125
+ supplement_mask = F.interpolate(supplement_mask, (height // 8, width // 8), mode="bilinear")
1126
+ supplement_mask = supplement_mask.to(device=device, dtype=self.unet.dtype) # (1, 1, H, W)
1127
+
1128
+ guidance_masks = np.concatenate(guidance_masks, axis=0)
1129
+ guidance_masks = guidance_masks[None, ...]
1130
+ guidance_masks = torch.from_numpy(guidance_masks).to(device=device, dtype=self.unet.dtype)
1131
+ guidance_masks = F.interpolate(
1132
+ guidance_masks, (height // 8, width // 8), mode="bilinear"
1133
+ ) # (1, instance_num, H, W)
1134
+
1135
+ in_box = torch.from_numpy(np.array(in_box))[None, ...].to(
1136
+ device=device, dtype=self.unet.dtype
1137
+ ) # (1, instance_num, 4)
1138
+
1139
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1140
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1141
 
 
1147
  guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1148
  ).to(device=device, dtype=latents.dtype)
1149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1150
  # 7. Denoising loop
1151
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1152
  self._num_timesteps = len(timesteps)
 
1174
  "ca_scale": ca_scale,
1175
  "ea_scale": ea_scale,
1176
  "sac_scale": sac_scale,
1177
+ "guidance_masks": guidance_masks,
1178
+ "supplement_mask": supplement_mask,
1179
+ "in_box": in_box,
1180
  }
1181
 
1182
  noise_pred = self.unet(