Eyalgut committed on
Commit 8baa624 · verified · 1 Parent(s): 36c6ec5

Update pipeline_bria.py

Files changed (1): pipeline_bria.py (+128 −28)
pipeline_bria.py CHANGED

@@ -1,4 +1,4 @@
-from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps
+from diffusers.pipelines.flux.pipeline_flux import FluxPipeline, retrieve_timesteps, calculate_shift
 from typing import Any, Callable, Dict, List, Optional, Union

 import torch
@@ -25,7 +25,9 @@ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
 from transformer_bria import BriaTransformer2DModel
 from bria_utils import get_t5_prompt_embeds, get_original_sigmas, is_ng_none
-
+from diffusers.utils.torch_utils import randn_tensor
+import diffusers
+import numpy as np
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm

@@ -78,10 +80,6 @@ class BriaPipeline(FluxPipeline):
         [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
     """

-    # model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder->transformer->vae"
-    # _optional_components = []
-    # _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
-
     def __init__(
         self,
         transformer: BriaTransformer2DModel,
@@ -109,6 +107,11 @@ class BriaPipeline(FluxPipeline):
         self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
         for block in self.text_encoder.encoder.block:
             block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
+
+        if self.vae.config.shift_factor is None:
+            self.vae.config.shift_factor = 0
+            self.vae.to(dtype=torch.float32)
+

     def encode_prompt(
         self,
@@ -326,10 +329,10 @@ class BriaPipeline(FluxPipeline):

         Examples:

-        Returns:
-            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
-            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
-            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        Returns:
+            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
+            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
+            images.
         """

         height = height or self.default_sample_size * self.vae_scale_factor
@@ -382,16 +385,7 @@ class BriaPipeline(FluxPipeline):
         if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)

-        # 4. Prepare timesteps
-        # Sample from training sigmas
-        if isinstance(self.scheduler, DDIMScheduler) or isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
-            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
-        else:
-            sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps)
-            timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas)
-
-        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
-        self._num_timesteps = len(timesteps)
+

         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4  # due to patch=2, we divide by 4
@@ -406,9 +400,42 @@ class BriaPipeline(FluxPipeline):
             latents,
         )

+        if isinstance(self.scheduler, FlowMatchEulerDiscreteScheduler) and self.scheduler.config['use_dynamic_shifting']:
+            sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+            image_seq_len = latents.shape[1]  # Shift by height - Why just height?
+            print(f"Using dynamic shift in pipeline with sequence length {image_seq_len}")
+
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+            timesteps, num_inference_steps = retrieve_timesteps(
+                self.scheduler,
+                num_inference_steps,
+                device,
+                timesteps,
+                sigmas,
+                mu=mu,
+            )
+        else:
+            # 4. Prepare timesteps
+            # Sample from training sigmas
+            if isinstance(self.scheduler, DDIMScheduler) or isinstance(self.scheduler, EulerAncestralDiscreteScheduler):
+                timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, None)
+            else:
+                sigmas = get_original_sigmas(num_train_timesteps=self.scheduler.config.num_train_timesteps, num_inference_steps=num_inference_steps)
+                timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps, sigmas=sigmas)
+
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
+
         # Support different diffusers versions
-        if len(latent_image_ids.shape) == 2:
-            text_ids = text_ids.squeeze()
+        if diffusers.__version__ >= '0.32.0':
+            latent_image_ids = latent_image_ids[0]
+            text_ids = text_ids[0]

         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
@@ -452,12 +479,6 @@ class BriaPipeline(FluxPipeline):
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

-
-                # if latents.std().item() > 2:
-                #     print('Warning')
-
-                # print(t.item(), latents.mean().item(), latents.std().item())
-
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
                         # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
@@ -554,9 +575,88 @@ class BriaPipeline(FluxPipeline):
         self.text_encoder = self.text_encoder.to(dtype=T5_PRECISION)
         for block in self.text_encoder.encoder.block:
             block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)
+
+        if self.vae.config.shift_factor == 0 and self.vae.dtype != torch.float32:
+            self.vae.to(dtype=torch.float32)
+

         return self

+
+    def prepare_latents(
+        self,
+        batch_size,
+        num_channels_latents,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        latents=None,
+    ):
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // self.vae_scale_factor)
+        width = 2 * (int(width) // self.vae_scale_factor)
+
+        shape = (batch_size, num_channels_latents, height, width)
+
+        if latents is not None:
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+            return latents.to(device=device, dtype=dtype), latent_image_ids
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
+
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
+
+        return latents, latent_image_ids
+
+    @staticmethod
+    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
+        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
+        latents = latents.permute(0, 2, 4, 1, 3, 5)
+        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
+
+        return latents
+
+    @staticmethod
+    def _unpack_latents(latents, height, width, vae_scale_factor):
+        batch_size, num_patches, channels = latents.shape
+
+        height = height // vae_scale_factor
+        width = width // vae_scale_factor
+
+        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.permute(0, 3, 1, 4, 2, 5)
+
+        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+
+        return latents
+
+    @staticmethod
+    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
+
+        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
+
+        latent_image_ids = latent_image_ids.repeat(batch_size, 1, 1, 1)
+        latent_image_ids = latent_image_ids.reshape(
+            batch_size, latent_image_id_height * latent_image_id_width, latent_image_id_channels
+        )
+
+        return latent_image_ids.to(device=device, dtype=dtype)
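The __init__ and to() changes keep the T5 text encoder in reduced precision while pinning each block's final feed-forward output projection (DenseReluDense.wo) to float32. A minimal standalone sketch of that trick; "t5-small" is purely a stand-in checkpoint, not the encoder this pipeline ships with:

import torch
from transformers import T5EncoderModel

# Run the encoder in bfloat16, but keep the last feed-forward projection of
# every block in float32 for numerical stability. "t5-small" is a stand-in.
encoder = T5EncoderModel.from_pretrained("t5-small", torch_dtype=torch.bfloat16)
for block in encoder.encoder.block:
    block.layer[-1].DenseReluDense.wo.to(dtype=torch.float32)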
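The new dynamic-shifting branch computes the flow-matching shift parameter mu as a linear interpolation over the image token count, then hands it to retrieve_timesteps. A standalone sketch of that computation; the default bounds below mirror common Flux scheduler configs and are assumptions, not values from this commit (the pipeline reads the real ones from self.scheduler.config):

import numpy as np

def calculate_shift_sketch(image_seq_len, base_seq_len=256, max_seq_len=4096,
                           base_shift=0.5, max_shift=1.16):
    # Linear interpolation: sequences of base_seq_len tokens get base_shift,
    # sequences of max_seq_len tokens get max_shift.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

num_inference_steps = 30
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
mu = calculate_shift_sketch(4096)  # e.g. a 1024x1024 image packed into 64x64 patches
print(f"mu={mu:.3f}, sigmas[:3]={sigmas[:3]}")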
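One caveat in the compatibility guard: diffusers.__version__ >= '0.32.0' compares strings lexicographically, so e.g. '0.9.0' >= '0.32.0' evaluates to True. A more robust variant of the same branch, sketched with the packaging library (the tensors are placeholders for illustration):

import torch
import diffusers
from packaging import version

latent_image_ids = torch.zeros(2, 4096, 3)  # placeholder ids with a batch dim
text_ids = torch.zeros(2, 128, 3)

# Parse versions instead of comparing strings lexicographically.
if version.parse(diffusers.__version__) >= version.parse("0.32.0"):
    latent_image_ids = latent_image_ids[0]
    text_ids = text_ids[0]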
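_pack_latents and _unpack_latents are exact inverses: packing folds each 2x2 latent patch into the channel dimension, turning (B, C, H, W) into a (B, H/2 * W/2, 4C) token sequence for the transformer. A self-contained round-trip check (sizes chosen arbitrarily):

import torch

B, C, H, W = 1, 4, 64, 64  # latent-space sizes, arbitrary
latents = torch.randn(B, C, H, W)

# Pack: fold each 2x2 spatial patch into the channel axis.
packed = latents.view(B, C, H // 2, 2, W // 2, 2)
packed = packed.permute(0, 2, 4, 1, 3, 5).reshape(B, (H // 2) * (W // 2), C * 4)

# Unpack: invert the permutation to recover the original layout.
unpacked = packed.view(B, H // 2, W // 2, C, 2, 2)
unpacked = unpacked.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)

assert torch.equal(latents, unpacked)  # packing is lossless
print(packed.shape)  # torch.Size([1, 1024, 16])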
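Putting it together, a hypothetical end-to-end call; the checkpoint id is a placeholder rather than a model from this commit, and the arguments follow the Flux-style API that BriaPipeline inherits:

import torch
from pipeline_bria import BriaPipeline  # the file changed in this commit

# "<org>/<checkpoint>" is a placeholder model id, not taken from this commit.
pipe = BriaPipeline.from_pretrained("<org>/<checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a photo of a red bicycle leaning on a brick wall",
    num_inference_steps=30,
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("out.png")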