Update src/utils.py
src/utils.py  +34 −7
@@ -600,8 +600,10 @@ def register_faster_forward(model, mod = '50ls'):
         timestep_cond: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,  # ADDED
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
+        down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,  # ADDED
         return_dict: bool = True,
     ) -> Union[UNet2DConditionOutput, Tuple]:
         r"""
@@ -739,18 +741,27 @@ def register_faster_forward(model, mod = '50ls'):
         down_block_res_samples = (sample,)
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+                # added for t2i adapters
+                additional_residuals = {}
+                if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0:
+                    additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
                 sample, res_samples = downsample_block(
                     hidden_states=sample,
                     temb=emb,
                     encoder_hidden_states=encoder_hidden_states,
                     attention_mask=attention_mask,
                     cross_attention_kwargs=cross_attention_kwargs,
+                    **additional_residuals
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+                if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0:
+                    sample += down_intrablock_additional_residuals.pop(0)
 
             down_block_res_samples += res_samples
 
+
+
         if down_block_additional_residuals is not None:
             new_down_block_res_samples = ()
 
@@ -762,15 +773,31 @@ def register_faster_forward(model, mod = '50ls'):
 
         down_block_res_samples = new_down_block_res_samples
 
+        # Handle ControlNet additional residuals
+        if down_block_additional_residuals is not None:
+            new_down_block_res_samples = ()
+            for down_block_res_sample, down_block_additional_residual in zip(
+                down_block_res_samples, down_block_additional_residuals
+            ):
+                down_block_res_sample = down_block_res_sample + down_block_additional_residual
+                new_down_block_res_samples += (down_block_res_sample,)
+            down_block_res_samples = new_down_block_res_samples
+
         # 4. mid
         if self.mid_block is not None:
-            sample = self.mid_block(
-                sample,
-                emb,
-                encoder_hidden_states=encoder_hidden_states,
-                attention_mask=attention_mask,
-                cross_attention_kwargs=cross_attention_kwargs,
-            )
+            if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+                sample = self.mid_block(
+                    sample,
+                    emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                    attention_mask=attention_mask,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )
+            else:
+                sample = self.mid_block(sample, emb)
+            # Handle T2I-Adapter-XL
+            if down_intrablock_additional_residuals is not None and len(down_intrablock_additional_residuals) > 0 and sample.shape == down_intrablock_additional_residuals[0].shape:
+                sample += down_intrablock_additional_residuals.pop(0)
 
         if mid_block_additional_residual is not None:
             sample = sample + mid_block_additional_residual