Spaces:
Runtime error
Runtime error
Update pipeline.py
Browse files- pipeline.py +56 -14
pipeline.py
CHANGED
|
@@ -100,6 +100,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 100 |
def _get_t5_prompt_embeds(
|
| 101 |
self,
|
| 102 |
prompt: Union[str, List[str]] = None,
|
|
|
|
| 103 |
num_images_per_prompt: int = 1,
|
| 104 |
max_sequence_length: int = 512,
|
| 105 |
device: Optional[torch.device] = None,
|
|
@@ -156,6 +157,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 156 |
|
| 157 |
text_inputs = self.tokenizer(
|
| 158 |
prompt,
|
|
|
|
| 159 |
padding="max_length",
|
| 160 |
max_length=self.tokenizer_max_length,
|
| 161 |
truncation=True,
|
|
@@ -188,14 +190,18 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 188 |
self,
|
| 189 |
prompt: Union[str, List[str]],
|
| 190 |
prompt_2: Union[str, List[str]],
|
|
|
|
| 191 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 192 |
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 193 |
device: Optional[torch.device] = None,
|
| 194 |
num_images_per_prompt: int = 1,
|
| 195 |
prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
| 196 |
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
|
|
|
| 197 |
max_sequence_length: int = 512,
|
| 198 |
lora_scale: Optional[float] = None,
|
|
|
|
| 199 |
):
|
| 200 |
r"""
|
| 201 |
|
|
@@ -232,7 +238,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 232 |
scale_lora_layers(self.text_encoder_2, lora_scale)
|
| 233 |
|
| 234 |
prompt = [prompt] if isinstance(prompt, str) else prompt
|
| 235 |
-
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
| 236 |
|
| 237 |
if prompt_embeds is None:
|
| 238 |
prompt_2 = prompt_2 or prompt
|
|
@@ -251,16 +256,6 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 251 |
device=device,
|
| 252 |
)
|
| 253 |
|
| 254 |
-
if self.text_encoder is not None:
|
| 255 |
-
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 256 |
-
# Retrieve the original scale by scaling back the LoRA layers
|
| 257 |
-
unscale_lora_layers(self.text_encoder, lora_scale)
|
| 258 |
-
|
| 259 |
-
if self.text_encoder_2 is not None:
|
| 260 |
-
if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
|
| 261 |
-
# Retrieve the original scale by scaling back the LoRA layers
|
| 262 |
-
unscale_lora_layers(self.text_encoder_2, lora_scale)
|
| 263 |
-
|
| 264 |
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 265 |
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 266 |
|
|
@@ -270,9 +265,10 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 270 |
self,
|
| 271 |
prompt,
|
| 272 |
prompt_2,
|
| 273 |
-
negative_prompt,
|
| 274 |
height,
|
| 275 |
width,
|
|
|
|
| 276 |
prompt_embeds=None,
|
| 277 |
pooled_prompt_embeds=None,
|
| 278 |
callback_on_step_end_tensor_inputs=None,
|
|
@@ -311,10 +307,56 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 311 |
raise ValueError(
|
| 312 |
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 313 |
)
|
|
|
|
|
|
|
| 314 |
|
| 315 |
if max_sequence_length is not None and max_sequence_length > 512:
|
| 316 |
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
|
|
|
|
|
|
|
| 318 |
@staticmethod
|
| 319 |
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
| 320 |
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
|
@@ -437,7 +479,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 437 |
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 438 |
height: Optional[int] = None,
|
| 439 |
width: Optional[int] = None,
|
| 440 |
-
negative_prompt: Union[str, List[str]] = None,
|
| 441 |
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 442 |
num_inference_steps: int = 4,
|
| 443 |
timesteps: List[int] = None,
|
|
@@ -457,7 +499,7 @@ class FluxWithCFGPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
|
|
| 457 |
):
|
| 458 |
height = height or self.default_sample_size * self.vae_scale_factor
|
| 459 |
width = width or self.default_sample_size * self.vae_scale_factor
|
| 460 |
-
|
| 461 |
# 1. Check inputs
|
| 462 |
self.check_inputs(
|
| 463 |
prompt,
|
|
|
|
| 100 |
def _get_t5_prompt_embeds(
|
| 101 |
self,
|
| 102 |
prompt: Union[str, List[str]] = None,
|
| 103 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 104 |
num_images_per_prompt: int = 1,
|
| 105 |
max_sequence_length: int = 512,
|
| 106 |
device: Optional[torch.device] = None,
|
|
|
|
| 157 |
|
| 158 |
text_inputs = self.tokenizer(
|
| 159 |
prompt,
|
| 160 |
+
negative_prompt,
|
| 161 |
padding="max_length",
|
| 162 |
max_length=self.tokenizer_max_length,
|
| 163 |
truncation=True,
|
|
|
|
| 190 |
self,
|
| 191 |
prompt: Union[str, List[str]],
|
| 192 |
prompt_2: Union[str, List[str]],
|
| 193 |
+
do_classifier_free_guidance: bool = True,
|
| 194 |
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 195 |
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 196 |
device: Optional[torch.device] = None,
|
| 197 |
num_images_per_prompt: int = 1,
|
| 198 |
prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 199 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 200 |
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
| 201 |
+
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 202 |
max_sequence_length: int = 512,
|
| 203 |
lora_scale: Optional[float] = None,
|
| 204 |
+
adapter_weights: Optional[float] = None,
|
| 205 |
):
|
| 206 |
r"""
|
| 207 |
|
|
|
|
| 238 |
scale_lora_layers(self.text_encoder_2, lora_scale)
|
| 239 |
|
| 240 |
prompt = [prompt] if isinstance(prompt, str) else prompt
|
|
|
|
| 241 |
|
| 242 |
if prompt_embeds is None:
|
| 243 |
prompt_2 = prompt_2 or prompt
|
|
|
|
| 256 |
device=device,
|
| 257 |
)
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
|
| 260 |
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
|
| 261 |
|
|
|
|
| 265 |
self,
|
| 266 |
prompt,
|
| 267 |
prompt_2,
|
| 268 |
+
negative_prompt=None,
|
| 269 |
height,
|
| 270 |
width,
|
| 271 |
+
lora_scale=None,
|
| 272 |
prompt_embeds=None,
|
| 273 |
pooled_prompt_embeds=None,
|
| 274 |
callback_on_step_end_tensor_inputs=None,
|
|
|
|
| 307 |
raise ValueError(
|
| 308 |
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
|
| 309 |
)
|
| 310 |
+
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
| 311 |
+
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
| 312 |
|
| 313 |
if max_sequence_length is not None and max_sequence_length > 512:
|
| 314 |
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
|
| 315 |
+
|
| 316 |
+
prompt_attention_mask = text_inputs.attention_mask
|
| 317 |
+
prompt_attention_mask = prompt_attention_mask.to(device)
|
| 318 |
+
|
| 319 |
+
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
| 320 |
+
prompt_embeds = prompt_embeds[0]
|
| 321 |
+
|
| 322 |
+
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
| 323 |
+
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
| 324 |
+
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
| 325 |
+
max_length = prompt_embeds.shape[1]
|
| 326 |
+
uncond_input = self.tokenizer(
|
| 327 |
+
uncond_tokens,
|
| 328 |
+
padding="max_length",
|
| 329 |
+
max_length=max_length,
|
| 330 |
+
truncation=True,
|
| 331 |
+
return_attention_mask=True,
|
| 332 |
+
add_special_tokens=True,
|
| 333 |
+
return_tensors="pt",
|
| 334 |
+
)
|
| 335 |
+
negative_prompt_attention_mask = uncond_input.attention_mask
|
| 336 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
| 337 |
+
|
| 338 |
+
negative_prompt_embeds = self.text_encoder(
|
| 339 |
+
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
| 340 |
+
)
|
| 341 |
+
negative_prompt_embeds = negative_prompt_embeds[0]
|
| 342 |
+
|
| 343 |
+
if do_classifier_free_guidance:
|
| 344 |
+
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
| 345 |
+
seq_len = negative_prompt_embeds.shape[1]
|
| 346 |
+
|
| 347 |
+
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
| 348 |
+
|
| 349 |
+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
| 350 |
+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
| 351 |
+
|
| 352 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
| 353 |
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
| 354 |
+
else:
|
| 355 |
+
negative_prompt_embeds = None
|
| 356 |
+
negative_prompt_attention_mask = None
|
| 357 |
|
| 358 |
+
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
| 359 |
+
|
| 360 |
@staticmethod
|
| 361 |
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
| 362 |
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
|
|
|
| 479 |
prompt_2: Optional[Union[str, List[str]]] = None,
|
| 480 |
height: Optional[int] = None,
|
| 481 |
width: Optional[int] = None,
|
| 482 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 483 |
negative_prompt_2: Optional[Union[str, List[str]]] = None,
|
| 484 |
num_inference_steps: int = 4,
|
| 485 |
timesteps: List[int] = None,
|
|
|
|
| 499 |
):
|
| 500 |
height = height or self.default_sample_size * self.vae_scale_factor
|
| 501 |
width = width or self.default_sample_size * self.vae_scale_factor
|
| 502 |
+
|
| 503 |
# 1. Check inputs
|
| 504 |
self.check_inputs(
|
| 505 |
prompt,
|