| from typing import Optional | |
| import torch | |
| import torch.nn as nn | |
| from diffusers.models import UNet2DConditionModel | |
| from diffusers.models.attention import Attention | |
| from diffusers.models.attention_processor import AttnProcessor2_0 | |
def add_imagedream_attn_processor(unet: UNet2DConditionModel) -> UNet2DConditionModel:
    """Replace every self-attention ("attn1") processor on ``unet`` with an
    :class:`ImageDreamAttnProcessor2_0` so multi-view attention can be used.

    Cross-attention processors (keys without "attn1") are kept as-is. The
    UNet is modified in place via ``set_attn_processor`` and also returned
    so the call can be chained.

    Args:
        unet: The diffusers UNet whose attention processors are patched.

    Returns:
        The same ``unet`` instance, with its "attn1" processors replaced.
    """
    attn_procs = {
        key: ImageDreamAttnProcessor2_0() if "attn1" in key else processor
        for key, processor in unet.attn_processors.items()
    }
    unet.set_attn_processor(attn_procs)
    return unet
class ImageDreamAttnProcessor2_0(AttnProcessor2_0):
    """Scaled-dot-product attention processor for multi-view generation.

    When ``num_views > 1``, the view dimension is folded out of the batch
    dimension before attention so self-attention runs jointly across all
    views of one sample, then the original batch layout is restored.
    With ``num_views == 1`` this behaves exactly like ``AttnProcessor2_0``.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        num_views: int = 1,
        *args,
        **kwargs,
    ):
        # Forward to the base processor POSITIONALLY. The previous form,
        # `super().__call__(attn=attn, ..., *args, **kwargs)`, unpacked
        # `*args` into the positional slots before keyword binding, so any
        # non-empty `args` raised "got multiple values for argument 'attn'"
        # even though the signature advertises `*args`.
        if num_views == 1:
            # Single view: plain AttnProcessor2_0 behavior.
            return super().__call__(
                attn,
                hidden_states,
                encoder_hidden_states,
                attention_mask,
                temb,
                *args,
                **kwargs,
            )

        input_ndim = hidden_states.ndim
        B = hidden_states.size(0)
        if B % num_views:
            raise ValueError(
                f"`batch_size`(got {B}) must be a multiple of `num_views`(got {num_views})."
            )
        # Number of distinct samples once views are grouped together.
        real_B = B // num_views

        # Fold the view axis out of the batch so attention spans all views.
        if input_ndim == 4:
            # (real_B*V, C, H, W) -> (real_B, V*C, H, W) -> transpose(1, 2).
            # NOTE(review): transpose(1, 2) swaps the merged-channel and
            # height axes, and the inverse below uses transpose(-1, -2)
            # (last two axes), which is not the mirror operation. Preserved
            # verbatim here — confirm against the upstream ImageDream
            # reference implementation before "fixing".
            H, W = hidden_states.shape[2:]
            hidden_states = hidden_states.reshape(real_B, -1, H, W).transpose(1, 2)
        else:
            # (real_B*V, L, C) -> (real_B, V*L, C): views concatenated
            # along the sequence axis.
            hidden_states = hidden_states.reshape(real_B, -1, hidden_states.size(-1))

        hidden_states = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            temb,
            *args,
            **kwargs,
        )

        # Restore the original per-view batch layout.
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(B, -1, H, W)
        else:
            hidden_states = hidden_states.reshape(B, -1, hidden_states.size(-1))
        return hidden_states