manbeast3b commited on
Commit
b274a91
·
verified ·
1 Parent(s): 9d52b5e

Update src/utils.py

Browse files
Files changed (1) hide show
  1. src/utils.py +34 -7
src/utils.py CHANGED
@@ -600,8 +600,10 @@ def register_faster_forward(model, mod = '50ls'):
600
  timestep_cond: Optional[torch.Tensor] = None,
601
  attention_mask: Optional[torch.Tensor] = None,
602
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
 
603
  down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
604
  mid_block_additional_residual: Optional[torch.Tensor] = None,
 
605
  return_dict: bool = True,
606
  ) -> Union[UNet2DConditionOutput, Tuple]:
607
  r"""
@@ -739,18 +741,27 @@ def register_faster_forward(model, mod = '50ls'):
739
  down_block_res_samples = (sample,)
740
  for downsample_block in self.down_blocks:
741
  if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
 
 
 
 
742
  sample, res_samples = downsample_block(
743
  hidden_states=sample,
744
  temb=emb,
745
  encoder_hidden_states=encoder_hidden_states,
746
  attention_mask=attention_mask,
747
  cross_attention_kwargs=cross_attention_kwargs,
 
748
  )
749
  else:
750
  sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
 
 
751
 
752
  down_block_res_samples += res_samples
753
 
 
 
754
  if down_block_additional_residuals is not None:
755
  new_down_block_res_samples = ()
756
 
@@ -762,15 +773,31 @@ def register_faster_forward(model, mod = '50ls'):
762
 
763
  down_block_res_samples = new_down_block_res_samples
764
 
 
 
 
 
 
 
 
 
 
 
765
  # 4. mid
766
  if self.mid_block is not None:
767
- sample = self.mid_block(
768
- sample,
769
- emb,
770
- encoder_hidden_states=encoder_hidden_states,
771
- attention_mask=attention_mask,
772
- cross_attention_kwargs=cross_attention_kwargs,
773
- )
 
 
 
 
 
 
774
 
775
  if mid_block_additional_residual is not None:
776
  sample = sample + mid_block_additional_residual
 
600
  timestep_cond: Optional[torch.Tensor] = None,
601
  attention_mask: Optional[torch.Tensor] = None,
602
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
603
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, # ADDED
604
  down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
605
  mid_block_additional_residual: Optional[torch.Tensor] = None,
606
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, # ADDED
607
  return_dict: bool = True,
608
  ) -> Union[UNet2DConditionOutput, Tuple]:
609
  r"""
 
741
  down_block_res_samples = (sample,)
742
  for downsample_block in self.down_blocks:
743
  if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
744
+ #added for t2i adapters
745
+ additional_residuals = {}
746
+ if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0:
747
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
748
  sample, res_samples = downsample_block(
749
  hidden_states=sample,
750
  temb=emb,
751
  encoder_hidden_states=encoder_hidden_states,
752
  attention_mask=attention_mask,
753
  cross_attention_kwargs=cross_attention_kwargs,
754
+ **additional_residuals
755
  )
756
  else:
757
  sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
758
+ if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0:
759
+ sample += down_intrablock_additional_residuals.pop(0)
760
 
761
  down_block_res_samples += res_samples
762
 
763
+
764
+
765
  if down_block_additional_residuals is not None:
766
  new_down_block_res_samples = ()
767
 
 
773
 
774
  down_block_res_samples = new_down_block_res_samples
775
 
776
+ # Handle ControlNet additional residuals
777
+ if down_block_additional_residuals is not None:
778
+ new_down_block_res_samples = ()
779
+ for down_block_res_sample, down_block_additional_residual in zip(
780
+ down_block_res_samples, down_block_additional_residuals
781
+ ):
782
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
783
+ new_down_block_res_samples += (down_block_res_sample,)
784
+ down_block_res_samples = new_down_block_res_samples
785
+
786
  # 4. mid
787
  if self.mid_block is not None:
788
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
789
+ sample = self.mid_block(
790
+ sample,
791
+ emb,
792
+ encoder_hidden_states=encoder_hidden_states,
793
+ attention_mask=attention_mask,
794
+ cross_attention_kwargs=cross_attention_kwargs,
795
+ )
796
+ else:
797
+ sample = self.mid_block(sample, emb)
798
+ #Handle T2I-Adapter-XL
799
+ if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape:
800
+ sample += down_intrablock_additional_residuals.pop(0)
801
 
802
  if mid_block_additional_residual is not None:
803
  sample = sample + mid_block_additional_residual