Update pipeline_stable_diffusion_xl_instantid.py
pipeline_stable_diffusion_xl_instantid.py
CHANGED
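This commit replaces a broken draft of _encode_prompt_image_emb, which referenced undefined names (image_embeds, bs, negative_prompt_image_emb) and left one assignment unfinished, with what matches the upstream InstantID implementation of the method. For orientation, the pipeline's __call__ is expected to invoke it along these lines (a sketch inferred from the new signature, not the verbatim call site, which is outside this hunk):

    # Sketch of the call site (assumed from the new signature; not shown in this diff).
    # image_embeds: the face embedding extracted by insightface from the ID image.
    prompt_image_emb = self._encode_prompt_image_emb(
        image_embeds,
        device,
        num_images_per_prompt,
        self.unet.dtype,
        do_classifier_free_guidance,
    )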
@@ -217,41 +217,29 @@ class StableDiffusionXLInstantIDPipeline(StableDiffusionXLControlNetPipeline):
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale
 
-    def _encode_prompt_image_emb(
-
-
-
-        device,
-        do_classifier_free_guidance,
-        prompt_image_emb=None,  # 5th, optional argument
-    ):
-        # Case where no image_embeds is provided: we disable this path
-        if image_embeds is None:
-            return None, None
-
-        # In your original version, they assign prompt_image_emb from image_embeds
-        if prompt_image_emb is None:
-            prompt_image_emb = image_embeds
-
-        # Make sure it is a tensor
-        if not isinstance(prompt_image_emb, torch.Tensor):
-            prompt_image_emb = torch.tensor(prompt_image_emb, device=device)
-        else:
-            prompt_image_emb =
-
-
-
-        if bs != num_images_per_prompt:
-            prompt_image_emb = prompt_image_emb.repeat_interleave(num_images_per_prompt, dim=0)
-
-        # Classifier-free guidance: create a negative embedding if needed
-        if do_classifier_free_guidance:
-
-            prompt_image_emb = torch.cat([negative_prompt_image_emb, prompt_image_emb], dim=0)
-        else:
-
-
-
+    def _encode_prompt_image_emb(self, prompt_image_emb, device, num_images_per_prompt, dtype, do_classifier_free_guidance):
+
+        if isinstance(prompt_image_emb, torch.Tensor):
+            prompt_image_emb = prompt_image_emb.clone().detach()
+        else:
+            prompt_image_emb = torch.tensor(prompt_image_emb)
+
+        prompt_image_emb = prompt_image_emb.reshape([1, -1, self.image_proj_model_in_features])
+
+        if do_classifier_free_guidance:
+            prompt_image_emb = torch.cat([torch.zeros_like(prompt_image_emb), prompt_image_emb], dim=0)
+        else:
+            prompt_image_emb = torch.cat([prompt_image_emb], dim=0)
+
+        prompt_image_emb = prompt_image_emb.to(device=self.image_proj_model.latents.device,
+                                               dtype=self.image_proj_model.latents.dtype)
+        prompt_image_emb = self.image_proj_model(prompt_image_emb)
+
+        bs_embed, seq_len, _ = prompt_image_emb.shape
+        prompt_image_emb = prompt_image_emb.repeat(1, num_images_per_prompt, 1)
+        prompt_image_emb = prompt_image_emb.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        return prompt_image_emb.to(device=device, dtype=dtype)
 
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
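For reference, a minimal standalone sketch of what the restored method does to the embedding shapes, assuming image_proj_model is InstantID's Resampler with image_proj_model_in_features = 512 and 16 learned queries (all concrete values below are illustrative, and the Resampler itself is stubbed out with a placeholder):

    import torch

    in_features = 512              # stands in for self.image_proj_model_in_features
    num_images_per_prompt = 2
    face_emb = torch.randn(512)    # stands in for the insightface ArcFace embedding

    # 1) Flatten to (1, -1, in_features): here (1, 1, 512).
    emb = torch.as_tensor(face_emb).reshape([1, -1, in_features])

    # 2) Classifier-free guidance: prepend an all-zero negative embedding,
    #    so the batch holds (unconditional, conditional).
    emb = torch.cat([torch.zeros_like(emb), emb], dim=0)

    # 3) The real image_proj_model maps (bs, n, 512) to (bs, num_queries, dim);
    #    a random tensor stands in for its output here.
    bs_embed, seq_len = emb.shape[0], 16
    projected = torch.randn(bs_embed, seq_len, 2048)   # placeholder projection

    # 4) Duplicate per requested image, then flatten back into the batch.
    projected = projected.repeat(1, num_images_per_prompt, 1)
    projected = projected.view(bs_embed * num_images_per_prompt, seq_len, -1)
    print(projected.shape)   # torch.Size([4, 16, 2048])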