Spaces:
Runtime error
Runtime error
Update pipeline.py
Browse files- pipeline.py +7 -11
pipeline.py
CHANGED
|
@@ -268,30 +268,26 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 268 |
" the batch size of `prompt`."
|
| 269 |
)
|
| 270 |
|
| 271 |
-
|
| 272 |
prompt=negative_prompt,
|
| 273 |
device=device,
|
| 274 |
num_images_per_prompt=num_images_per_prompt,
|
| 275 |
)
|
| 276 |
|
| 277 |
-
|
| 278 |
prompt=negative_prompt_2,
|
| 279 |
device=device,
|
| 280 |
num_images_per_prompt=num_images_per_prompt,
|
| 281 |
max_sequence_length=max_sequence_length,
|
| 282 |
)
|
| 283 |
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
(0,
|
| 287 |
)
|
| 288 |
|
| 289 |
negative_prompt_embeds = torch.cat([negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-2)
|
| 290 |
|
| 291 |
-
|
| 292 |
-
negative_pooled_prompt_embeds = torch.cat(
|
| 293 |
-
[negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-1
|
| 294 |
-
)
|
| 295 |
|
| 296 |
if self.text_encoder is not None:
|
| 297 |
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
|
@@ -309,8 +305,8 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 309 |
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 310 |
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 311 |
|
| 312 |
-
|
| 313 |
-
|
| 314 |
|
| 315 |
return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
|
| 316 |
|
|
|
|
| 268 |
" the batch size of `prompt`."
|
| 269 |
)
|
| 270 |
|
| 271 |
+
negative_pooled_prompt_embeds = self._get_clip_prompt_embeds(
|
| 272 |
prompt=negative_prompt,
|
| 273 |
device=device,
|
| 274 |
num_images_per_prompt=num_images_per_prompt,
|
| 275 |
)
|
| 276 |
|
| 277 |
+
t5_negative_prompt_embeds = self._get_t5_prompt_embeds(
|
| 278 |
prompt=negative_prompt_2,
|
| 279 |
device=device,
|
| 280 |
num_images_per_prompt=num_images_per_prompt,
|
| 281 |
max_sequence_length=max_sequence_length,
|
| 282 |
)
|
| 283 |
|
| 284 |
+
negative_pooled_prompt_embeds = torch.nn.functional.pad(
|
| 285 |
+
negative_pooled_prompt_embeds,
|
| 286 |
+
(0, t5_negative_prompt_embeds.shape[-1] - negative_pooled_prompt_embeds.shape[-1]),
|
| 287 |
)
|
| 288 |
|
| 289 |
negative_prompt_embeds = torch.cat([negative_clip_prompt_embed, t5_negative_prompt_embed], dim=-2)
|
| 290 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
if self.text_encoder is not None:
|
| 293 |
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
|
|
|
| 305 |
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 306 |
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 307 |
|
| 308 |
+
negative_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 309 |
+
negative_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
|
| 310 |
|
| 311 |
return prompt_embeds, pooled_prompt_embeds, text_ids, negative_prompt_embeds, negative_pooled_prompt_embeds
|
| 312 |
|