Leteint committed on
Commit
81aeb88
·
verified ·
1 Parent(s): cae5178

Update pipeline_stable_diffusion_xl_instantid.py

Browse files
pipeline_stable_diffusion_xl_instantid.py CHANGED
@@ -217,41 +217,29 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
217
  if isinstance(attn_processor, IPAttnProcessor):
218
  attn_processor.scale = scale
219
 
220
- def _encode_prompt_image_emb(
221
- self,
222
- image_embeds,
223
- num_images_per_prompt,
224
- device,
225
- do_classifier_free_guidance,
226
- prompt_image_emb=None, # 5ème argument optionnel
227
- ):
228
- # Cas où on ne fournit pas d'image_embeds : on désactive ce chemin
229
- if image_embeds is None:
230
- return None, None
231
-
232
- # Dans ta version d’origine, ils assignent prompt_image_emb depuis image_embeds
233
- if prompt_image_emb is None:
234
- prompt_image_emb = image_embeds
235
-
236
- # S'assurer que c'est un tensor
237
- if not isinstance(prompt_image_emb, torch.Tensor):
238
- prompt_image_emb = torch.tensor(prompt_image_emb, device=device)
239
  else:
240
- prompt_image_emb = prompt_image_emb.to(device)
241
-
242
- # Répéter pour num_images_per_prompt
243
- bs = prompt_image_emb.shape[0]
244
- if bs != num_images_per_prompt:
245
- prompt_image_emb = prompt_image_emb.repeat_interleave(num_images_per_prompt, dim=0)
246
-
247
- # Classifier-free guidance : on crée un embedding négatif si besoin
248
  if do_classifier_free_guidance:
249
- negative_prompt_image_emb = torch.zeros_like(prompt_image_emb, device=device)
250
- prompt_image_emb = torch.cat([negative_prompt_image_emb, prompt_image_emb], dim=0)
251
  else:
252
- negative_prompt_image_emb = None
 
 
 
 
253
 
254
- return prompt_image_emb, negative_prompt_image_emb
 
 
 
 
255
 
256
  @torch.no_grad()
257
  @replace_example_docstring(EXAMPLE_DOC_STRING)
 
217
  if isinstance(attn_processor, IPAttnProcessor):
218
  attn_processor.scale = scale
219
 
220
+ def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance):
221
+
222
+ if isinstance(prompt_image_emb, torch.Tensor):
223
+ prompt_image_emb = prompt_image_emb.clone().detach()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  else:
225
+ prompt_image_emb = torch.tensor(prompt_image_emb)
226
+
227
+ prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
228
+
 
 
 
 
229
  if do_classifier_free_guidance:
230
+ prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
 
231
  else:
232
+ prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
233
+
234
+ prompt_image_emb = prompt_image_emb.to(device=self.image_proj_model.latents.device,
235
+ dtype=self.image_proj_model.latents.dtype)
236
+ prompt_image_emb = self.image_proj_model(prompt_image_emb)
237
 
238
+ bs_embed, seq_len, _ = prompt_image_emb.shape
239
+ prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
240
+ prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
241
+
242
+ return prompt_image_emb.to(device=device, dtype=dtype)
243
 
244
  @torch.no_grad()
245
  @replace_example_docstring(EXAMPLE_DOC_STRING)