Update pipeline_bria.py
pipeline_bria.py  +128 -28

pipeline_bria.py CHANGED
@@ -1,4 +1,4 @@
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps
+from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps, calculate_shift
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
@@ -25,7 +25,9 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
 from transformer_bria import BriaTransformer2DModel
 from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none
-
+from diffusers.utils.torch_utils import randn_tensor
+import diffusers
+import numpy as np
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -78,10 +80,6 @@ class BriaPipeline(FluxPipeline):
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
     """
 
-    # model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
-    # _optional_components = []
-    # _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
-
     def __init__(
         self,
         transformer: BriaTransformer2DModel,
@@ -109,6 +107,11 @@ class BriaPipeline(FluxPipeline):
         self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
         for block in self.text_encoder.encoder.block:
             block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
+
+        if self.vae.config.shift_factor is None:
+            self.vae.config.shift_factor = 0
+            self.vae.to(dtype=torch.float32)
+
 
     def encode_prompt(
         self,
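Keeping each block's `wo` projection in float32 while the rest of the text encoder runs at T5_PRECISION is the usual workaround for T5's feed-forward output projection overflowing in half precision. A minimal sanity check of the resulting mixed precision, assuming a constructed pipeline named `pipe` (hypothetical name):

    import torch

    # every block's final feed-forward output projection should remain fp32
    for i, block in enumerate(pipe.text_encoder.encoder.block):
        wo = block.layer[-1].DenseReluDense.wo
        assert wo.weight.dtype == torch.float32, f"block {i}: wo is {wo.weight.dtype}"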
@@ -326,10 +329,10 @@ class BriaPipeline(FluxPipeline):
 
         Examples:
 
-
-            [`~pipelines.
-
-
+        Returns:
+            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+            images.
         """
 
         height = height or self.default_sample_size * self.vae_scale_factor
@@ -382,16 +385,7 @@ class BriaPipeline(FluxPipeline):
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
-
-        # Sample from training sigmas
-        if isinstance(self.scheduler, DDIMScheduler) or isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
-            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
-        else:
-            sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps)
-            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas)
-
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
+
 
         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4  # due to patch=2, we divide by 4
@@ -406,9 +400,42 @@ class BriaPipeline(FluxPipeline):
             latents,
         )
 
+        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
+            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+            image_seq_len = latents.shape[1]  # Shift by height - Why just height?
+            print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
+
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+            timesteps, num_inference_steps = retrieve_timesteps(
+                self.scheduler,
+                num_inference_steps,
+                device,
+                timesteps,
+                sigmas,
+                mu=mu,
+            )
+        else:
+            # 4. Prepare timesteps
+            # Sample from training sigmas
+            if isinstance(self.scheduler, DDIMScheduler) or isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
+                timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
+            else:
+                sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps)
+                timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas)
+
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
         # Support different diffusers versions
-        if
-
+        if diffusers.__version__ >= '0.32.0':
+            latent_image_ids = latent_image_ids[0]
+            text_ids = text_ids[0]
 
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
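The dynamic-shifting branch mirrors Flux's resolution-dependent schedule: `calculate_shift` linearly interpolates a log-shift `mu` between `base_shift` and `max_shift` as the packed image token count grows from `base_image_seq_len` to `max_image_seq_len`, and the flow-match scheduler then warps the sigma schedule by that `mu`. A self-contained sketch of the arithmetic (the default values shown are Flux's, not necessarily this scheduler's config):

    import numpy as np

    def calc_mu(seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
        # linear interpolation in the token count, as in diffusers' calculate_shift
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        return seq_len * m + (base_shift - m * base_seq_len)

    num_inference_steps = 30
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    mu = calc_mu(4096)  # e.g. a 1024x1024 image packs to 64*64 = 4096 tokens
    shifted = np.exp(mu) / (np.exp(mu) + (1 / sigmas - 1))  # exponential time shift
    print(round(mu, 3), shifted[:3])  # larger mu keeps more steps at high noise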
@@ -452,12 +479,6 @@ class BriaPipeline(FluxPipeline):
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
 
-
-                # if latents.std().item()>2:
-                #     print('Warning')
-
-                # print(t.item(),latents.mean().item(),latents.std().item())
-
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
@@ -554,9 +575,88 @@ class BriaPipeline(FluxPipeline):
         self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
         for block in self.text_encoder.encoder.block:
             block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
+
+        if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32:
+            self.vae.to(dtype=torch.float32)
+
 
         return self
 
+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // self.vae_scale_factor)
+        width = 2 * (int(width) // self.vae_scale_factor)
+
+        shape = (batch_size, num_channels_latents, height, width)
+
+        if latents is not None:
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        return latents, latent_image_ids
+
+    @staticmethod
+    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+        return latents
+
+    @staticmethod
+    def _unpack_latents(latents, height, width, vae_scale_factor):
+        batch_size, num_patches, channels = latents.shape
+
+        height = height // vae_scale_factor
+        width = width // vae_scale_factor
+
+        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+        return latents
+
+    @staticmethod
+    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+        latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
+        latent_image_ids = latent_image_ids.reshape(
+            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+        )
+
+        return latent_image_ids.to(device=device, dtype=dtype)
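The three static helpers replicate Flux's latent packing: `_pack_latents` folds each 2x2 latent patch into the channel dimension, turning `(B, C, H, W)` latents into `(B, H/2 * W/2, 4C)` transformer tokens, and `_unpack_latents` inverts the permutation. A standalone round-trip check of that reshape logic (pure torch, no pipeline required):

    import torch

    B, C, H, W = 1, 4, 8, 8  # latent-space dims; H and W must be even
    x = torch.randn(B, C, H, W)

    # pack: (B, C, H, W) -> (B, H/2 * W/2, 4C), as in _pack_latents
    packed = x.view(B, C, H // 2, 2, W // 2, 2).permute(0, 2, 4, 1, 3, 5)
    packed = packed.reshape(B, (H // 2) * (W // 2), C * 4)

    # unpack: the inverse view/permute, as in _unpack_latents
    unpacked = packed.view(B, H // 2, W // 2, C, 2, 2).permute(0, 3, 1, 4, 2, 5)
    unpacked = unpacked.reshape(B, C, H, W)

    assert torch.equal(x, unpacked)  # exact inverse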
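One caveat on the `# Support different diffusers versions` check above: `diffusers.__version__ >= '0.32.0'` compares strings lexicographically, so for example `'0.4.0'` sorts after `'0.32.0'` while `'0.100.0'` sorts before it. A sketch of the safer comparison using `packaging`, which the Hugging Face libraries already depend on:

    from packaging import version
    import diffusers

    # semantic version comparison instead of string comparison
    UNBATCHED_IDS = version.parse(diffusers.__version__) >= version.parse("0.32.0")
    print(diffusers.__version__, UNBATCHED_IDS)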