# Copyright 2023 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import math from torch import nn from torch.nn import functional as F import einops from timm.models.layers import Mlp from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.utils import BaseOutput, logging from diffusers.models.embeddings import TimestepEmbedding, Timesteps from diffusers.models.modeling_utils import ModelMixin from diffusers.models.resnet import Downsample2D, ResnetBlock2D from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D from diffusers.models.upsampling import Upsample2D from diffusers.utils import deprecate, is_torch_version from einops import rearrange from diffusers.models.attention import Attention,FeedForward from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed from diffusers.models.embeddings import get_2d_sincos_pos_embed,get_1d_sincos_pos_embed_from_grid,get_3d_sincos_pos_embed,TimestepEmbedding, Timesteps VIS_ATTEN_FLAG = False attention_maps = [] def get_attention_maps(): global attention_maps return attention_maps def clear_attention_maps(): global attention_maps attention_maps.clear() def set_vis_atten_flag(flag): global VIS_ATTEN_FLAG VIS_ATTEN_FLAG = flag logger = logging.get_logger(__name__) # pylint: disable=invalid-name # ----------------------- A2M Type1 Predict Pose ------------------------ class DownEncoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_downsample: bool = True, downsample_padding: int = 1, ): super().__init__() resnets = [] for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=None, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if add_downsample: self.downsamplers = nn.ModuleList( [ Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] ) else: self.downsamplers = None def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: if len(args) > 0 or kwargs.get("scale", None) is not None: deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." 
deprecate("scale", "1.0.0", deprecation_message) for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=None) if self.downsamplers is not None: for downsampler in self.downsamplers: hidden_states = downsampler(hidden_states) return hidden_states class MidBlock2D(nn.Module): def __init__( self, in_channel: int = 64, out_channel: int = 1280, ): super().__init__() self.mid_convs = nn.ModuleList() self.mid_convs.append(nn.Sequential( nn.Conv2d( in_channels=in_channel, out_channels=in_channel, kernel_size=3, stride=1, padding=1 ), nn.ReLU(), nn.Conv2d( in_channels=in_channel, out_channels=in_channel, kernel_size=3, stride=1, padding=1 ), )) self.mid_convs.append(nn.Conv2d( in_channels=in_channel, out_channels=out_channel, kernel_size=1, stride=1, )) def forward(self, x): for mid_conv in self.mid_convs: sample = mid_conv(x) return sample class UpDecoderBlock2D(nn.Module): def __init__( self, in_channels: int, out_channels: int, resolution_idx: Optional[int] = None, dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", # default, spatial resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, output_scale_factor: float = 1.0, add_upsample: bool = True, temb_channels: Optional[int] = None, ): super().__init__() resnets = [] for i in range(num_layers): input_channels = in_channels if i == 0 else out_channels resnets.append( ResnetBlock2D( in_channels=input_channels, out_channels=out_channels, temb_channels=temb_channels, eps=resnet_eps, groups=resnet_groups, dropout=dropout, time_embedding_norm=resnet_time_scale_shift, non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, ) ) self.resnets = nn.ModuleList(resnets) if add_upsample: self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.resolution_idx = resolution_idx def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: for resnet in self.resnets: hidden_states = resnet(hidden_states, temb=temb) if self.upsamplers is not None: for upsampler in self.upsamplers: hidden_states = upsampler(hidden_states) return hidden_states class DuoFrameDownEncoder(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channel: int = 4, block_out_channels : Tuple[int] = (64, 128, 256, 256), norm_groups : int = 4, resnet_layers_per_block: int = 2, add_attention : bool = True, ): super().__init__() # conv_in self.conv_in = nn.Conv2d( in_channel, block_out_channels[0], kernel_size=3, stride=1, padding=1, ) # downblock self.downblock = nn.ModuleList() output_channel = block_out_channels[0] for i,channels in enumerate(block_out_channels): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 self.downblock.append( DownEncoderBlock2D( in_channels=input_channel, out_channels=output_channel, num_layers= resnet_layers_per_block, resnet_groups = norm_groups, add_downsample=not is_final_block, ) ) # mid_block self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], resnet_eps=1e-6, output_scale_factor=1, resnet_time_scale_shift="default", attention_head_dim=block_out_channels[-1], resnet_groups=norm_groups, temb_channels=None, add_attention=add_attention, ) # conv_out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_groups, eps=1e-6) 
self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], 3, padding=1) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward(self, x: torch.FloatTensor) -> torch.Tensor: """ Args: * x : (b,c,h,w) Output: * x : (b,c',h/8,w/8) """ # conv_in x = self.conv_in(x) # downblock for downblock in self.downblock: x = downblock(x) # mid x = self.mid_block(x) # out x = self.conv_norm_out(x) x = self.conv_act(x) x = self.conv_out(x) return x class MotionDownEncoder(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channel: int = 4, block_out_channels : Tuple[int] = (64, 128, 256, 256), norm_groups : int = 32, resnet_layers_per_block: int = 2, add_attention : bool = True, ): super().__init__() # conv_in self.conv_in = nn.Conv2d( in_channel, block_out_channels[0], kernel_size=1, stride=1, ) # downblock self.downblock = nn.ModuleList() output_channel = block_out_channels[0] for i,channels in enumerate(block_out_channels): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 self.downblock.append( DownEncoderBlock2D( in_channels=input_channel, out_channels=output_channel, num_layers= resnet_layers_per_block, resnet_groups = norm_groups, add_downsample=not is_final_block, ) ) # mid_block self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], resnet_eps=1e-6, output_scale_factor=1, resnet_time_scale_shift="default", attention_head_dim=block_out_channels[-1], resnet_groups=norm_groups, temb_channels=None, add_attention=add_attention, ) # conv_out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], 3, padding=1) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward(self, x: torch.FloatTensor) -> torch.Tensor: """ Args: * x : (b,c,h,w) Output: * x : (b,c',h/8,w/8) """ # conv_in x = self.conv_in(x) # downblock for downblock in self.downblock: x = downblock(x) # mid x = self.mid_block(x) # out x = self.conv_norm_out(x) x = self.conv_act(x) x = self.conv_out(x) return x class DownEncoder(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channel: int = 4, block_out_channels : Tuple[int] = (64, 128, 256, 256), norm_groups : int = 8, resnet_layers_per_block: int = 2, add_attention : bool = True, ): super().__init__() # conv_in self.conv_in = nn.Conv2d( in_channel, block_out_channels[0], kernel_size=1, stride=1, ) # downblock self.downblock = nn.ModuleList() output_channel = block_out_channels[0] for i,channels in enumerate(block_out_channels): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 self.downblock.append( DownEncoderBlock2D( in_channels=input_channel, out_channels=output_channel, num_layers= resnet_layers_per_block, resnet_groups = norm_groups, add_downsample=not is_final_block, ) ) # mid_block self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], resnet_eps=1e-6, output_scale_factor=1, resnet_time_scale_shift="default", attention_head_dim=block_out_channels[-1], resnet_groups=norm_groups, temb_channels=None, add_attention=add_attention, 
) # conv_out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], 3, padding=1) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward(self, x: torch.FloatTensor) -> torch.Tensor: """ Args: * x : (b,c,h,w) Output: * x : (b,c',h/8,w/8) """ # conv_in x = self.conv_in(x) # downblock for downblock in self.downblock: x = downblock(x) # mid x = self.mid_block(x) # out x = self.conv_norm_out(x) x = self.conv_act(x) x = self.conv_out(x) return x class Upsampler(ModelMixin, ConfigMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, in_channel: int = 256, out_channel: Optional[int] = None, block_out_channels : Tuple[int] = (256, 256, 128, 64), norm_groups : int = 8, resnet_layers_per_block: int = 2, add_attention : bool = True, ): super().__init__() self.out_channel = out_channel # conv_in self.conv_in = nn.Conv2d( in_channel, block_out_channels[0], kernel_size=3, stride=1, padding=1, ) # mid_block self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[0], resnet_eps=1e-6, output_scale_factor=1, resnet_time_scale_shift="default", attention_head_dim=block_out_channels[0], resnet_groups=norm_groups, temb_channels=None, add_attention=add_attention, ) # upblock self.upblock = nn.ModuleList() output_channel = block_out_channels[0] for i,channels in enumerate(block_out_channels): input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 self.upblock.append( UpDecoderBlock2D( in_channels=input_channel, out_channels=output_channel, num_layers= resnet_layers_per_block, resnet_groups = norm_groups, add_upsample=not is_final_block, ) ) # conv_out self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_groups, eps=1e-6) self.conv_act = nn.SiLU() self.conv_out = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], 3, padding=1) # channel if self.out_channel: self.conv_final = nn.Conv2d( block_out_channels[-1], out_channel, kernel_size=3, stride=1, padding=1, ) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward(self, x: torch.FloatTensor) -> torch.tensor: """ Args: * x : (b,c,h,w) Output: * x : (b,c',h*8,w*8) """ # conv_in x = self.conv_in(x) # mid x = self.mid_block(x) # upblock for upblock in self.upblock: x = upblock(x) # out x = self.conv_norm_out(x) x = self.conv_act(x) x = self.conv_out(x) # final if self.out_channel: x = self.conv_final(x) return x # mapping for same shape & different channels class MapConv(nn.Module): def __init__(self, in_channel: int = 8, hidden : int = 640, out_channel: int = 4, block_layer : int = 8, goups : int = 2,): super().__init__() # conv_in self.conv_in = nn.Conv2d( in_channel, hidden, kernel_size=3, stride=1, padding=1, ) # attn self.mid_block = UNetMidBlock2D( in_channels=hidden, resnet_eps=1e-6, output_scale_factor=1, resnet_time_scale_shift="default", attention_head_dim=64, resnet_groups=goups, temb_channels=None, add_attention=True, ) # map self.map = nn.ModuleList() for i in range(block_layer): resnet = ResnetBlock2D( in_channels=hidden, out_channels=hidden, temb_channels=None, groups=goups, ) self.map.append(resnet) # conv_out self.conv_out = nn.Conv2d( hidden, out_channel, 
kernel_size=3, stride=1, padding=1, ) def forward(self, x: torch.tensor , temb: Optional[torch.tensor] = None) -> torch.tensor: x = self.conv_in(x) x = self.mid_block(x) for l in self.map: x = l(x,None) x = self.conv_out(x) return x def simple_attention_processor( attn: Attention, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, ): residual = hidden_states if attn.spatial_norm is not None: hidden_states = attn.spatial_norm(hidden_states, None) input_ndim = hidden_states.ndim if input_ndim == 4: batch_size, channel, height, width = hidden_states.shape hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) batch_size, sequence_length, _ = ( hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape ) if attention_mask is not None: attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) # scaled_dot_product_attention expects attention_mask shape to be # (batch, heads, source_length, target_length) attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) if attn.group_norm is not None: hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) query = attn.to_q(hidden_states) if encoder_hidden_states is None: encoder_hidden_states = hidden_states elif attn.norm_cross: encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) if attn.norm_q is not None: query = attn.norm_q(query) if attn.norm_k is not None: key = attn.norm_k(key) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 score = torch.einsum("bhld,bhsd->bhls", query, key) / math.sqrt(head_dim) return score.softmax(dim=-1) class BasicTransformerBlock(nn.Module): r""" Parameters: dim (`int`): The number of channels in the input and output. num_attention_heads (`int`): The number of heads to use for multi-head attention. attention_head_dim (`int`): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. attention_bias (: obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. qk_norm (`bool`, defaults to `True`): Whether or not to use normalization after query and key projections in Attention. norm_elementwise_affine (`bool`, defaults to `True`): Whether to use learnable elementwise affine parameters for normalization. norm_eps (`float`, defaults to `1e-5`): Epsilon value for normalization layers. final_dropout (`bool` defaults to `False`): Whether to apply a final dropout after the last feed-forward layer. ff_inner_dim (`int`, *optional*, defaults to `None`): Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. ff_bias (`bool`, defaults to `True`): Whether or not to use bias in Feed-forward layer. attention_out_bias (`bool`, defaults to `True`): Whether or not to use bias in Attention output projection layer. 
""" def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: Optional[int] = None, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1. Self Attention self.norm1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2. Feed Forward self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, ) -> torch.Tensor: # norm1 norm_hidden_states = self.norm1(hidden_states) # attention attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, ) # if VIS_ATTEN_FLAG: # global attention_maps # attn_score = simple_attention_processor(self.attn1, hidden_states) # attention_maps.append(attn_score.detach().cpu()) hidden_states = hidden_states + attn_output # norm & modulate norm_hidden_states = self.norm2(hidden_states) # feed-forward ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + ff_output return hidden_states # patch embed without positional encoding class PatchEmbed(nn.Module): def __init__( self, patch_size: int = 2, in_channels: int = 16, embed_dim: int = 1920, bias: bool = True, ) -> None: super().__init__() self.patch_size = patch_size self.proj = nn.Conv2d( in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias ) def forward(self, image_embeds: torch.Tensor): r""" Args: image_embeds (`torch.Tensor`): Input image embeddings. 
Expected shape: (batch_size, num_frames, channels, height, width) or (batch_size, channels, height, width) Returns: embeds (`torch.Tensor`): (batch_size,num_frames x height x width,embed_dim) or (batch_size,1 x height x width,embed_dim) """ if image_embeds.dim() == 5: batch, num_frames, channels, height, width = image_embeds.shape image_embeds = image_embeds.reshape(-1, channels, height, width) else: batch, channels, height, width = image_embeds.shape num_frames = 1 image_embeds = self.proj(image_embeds) image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels] image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels] return image_embeds # [batch, num_frames x height x width, channels] class AMDLayerNormZero(nn.Module): def __init__( self, conditioning_dim: int, embedding_dim: int, elementwise_affine: bool = True, eps: float = 1e-5, bias: bool = True, ) -> None: super().__init__() self.embed_dim = embedding_dim self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] class AMDLayerNormZero_OneVariable(nn.Module): def __init__( self, conditioning_dim: int, embedding_dim: int, elementwise_affine: bool = True, eps: float = 1e-5, bias: bool = True, ) -> None: super().__init__() self.embed_dim = embedding_dim self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_dim, 3 * embedding_dim, bias=bias) self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate = self.linear(self.silu(temb)).chunk(3, dim=1) hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] return hidden_states, gate[:, None, :] class AMDLayerNormZero2Condition(nn.Module): def __init__( self, conditioning_dim: int, embedding_dim: int, elementwise_affine: bool = True, eps: float = 1e-5, bias: bool = True, ) -> None: super().__init__() self.embed_dim = embedding_dim self.silu = nn.SiLU() self.linear = nn.Linear(conditioning_dim, 9 * embedding_dim, bias=bias) self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) def forward( self, hidden_states: torch.Tensor, condition_states1: torch.Tensor,condition_states2:torch.Tensor, temb: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: shift, scale, gate, c1_shift, c1_scale, c1_gate,c2_shift, c2_scale, c2_gate = self.linear(self.silu(temb)).chunk(9, dim=1) hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] condition_states1 = self.norm(condition_states1) * (1 + c1_scale)[:, None, :] + c1_shift[:, None, :] condition_states2 = self.norm(condition_states2) * (1 + c2_scale)[:, None, :] + c2_shift[:, None, :] return hidden_states, 
condition_states1,condition_states2, gate[:, None, :], c1_gate[:, None, :],c2_gate[:, None, :] class AdaLayerNorm(nn.Module): r""" Norm layer modified to incorporate timestep embeddings. Parameters: embedding_dim (`int`): The size of each embedding vector. num_embeddings (`int`, *optional*): The size of the embeddings dictionary. output_dim (`int`, *optional*): norm_elementwise_affine (`bool`, defaults to `False): norm_eps (`bool`, defaults to `False`): chunk_dim (`int`, defaults to `0`): """ def __init__( self, embedding_dim: int, num_embeddings: Optional[int] = None, output_dim: Optional[int] = None, norm_elementwise_affine: bool = False, norm_eps: float = 1e-5, chunk_dim: int = 0, ): super().__init__() self.chunk_dim = chunk_dim output_dim = output_dim or embedding_dim * 2 if num_embeddings is not None: self.emb = nn.Embedding(num_embeddings, embedding_dim) else: self.emb = None self.silu = nn.SiLU() self.linear = nn.Linear(embedding_dim, output_dim) self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) def forward( self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None ) -> torch.Tensor: # temb (8*15,512) x (8*15,16,256) if self.emb is not None: temb = self.emb(timestep) temb = self.linear(self.silu(temb)) if self.chunk_dim == 1: # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the # other if-branch. This branch is specific to CogVideoX for now. shift, scale = temb.chunk(2, dim=1) shift = shift[:, None, :] scale = scale[:, None, :] else: scale, shift = temb.chunk(2, dim=0) x = self.norm(x) * (1 + scale) + shift return x class AMDTransformerBlock(nn.Module): r""" AMDTransformerBlock """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1. Self Attention self.norm1 = AMDLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2. 
Feed Forward self.norm2 = AMDLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, ) -> torch.Tensor: """ ************************* ****************** * encoder_hidden_states * * hidden_states * ************************* ****************** """ norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( hidden_states, encoder_hidden_states, temb ) # attention image_length = norm_encoder_hidden_states.shape[1] # AMD uses concatenated image + motion embeddings with self-attention instead of using # them in cross-attention individually # print(norm_encoder_hidden_states.shape) # print(norm_hidden_states.shape) norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, ) if VIS_ATTEN_FLAG: global attention_maps attn_score = simple_attention_processor(self.attn1, norm_hidden_states) attention_maps.append(attn_score.detach().cpu()) hidden_states = hidden_states + gate_msa * attn_output[:, image_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :image_length] # norm & modulate norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( hidden_states, encoder_hidden_states, temb ) # feed-forward norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate_ff * ff_output[:, image_length:] encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :image_length] return hidden_states, encoder_hidden_states class BasicDiTBlock(nn.Module): r""" AMDTransformerBlock """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1. Self Attention self.norm1 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2. 
Feed Forward self.norm2 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, ) -> torch.Tensor: norm_hidden_states, gate_msa = self.norm1( hidden_states, temb ) # N,F,D attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, ) hidden_states = hidden_states + gate_msa * attn_output # norm & modulate norm_hidden_states, gate_ff = self.norm2( hidden_states, temb ) # feed-forward ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate_ff * ff_output return hidden_states class AMDTransformerMotionBlock(nn.Module): r""" AMDTransformerBlock """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1. Self Attention self.norm1 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2. Feed Forward self.norm2 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, temb: torch.Tensor, ) -> torch.Tensor: norm_hidden_states, gate_msa = self.norm1( hidden_states, temb ) # N,F,D attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, ) hidden_states = hidden_states + gate_msa * attn_output # norm & modulate norm_hidden_states, gate_ff = self.norm2( hidden_states, temb ) # feed-forward ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate_ff * ff_output return hidden_states class TransformerBlock2Condition(nn.Module): r""" AMDTransformerBlock """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1. Self Attention self.norm1 = AMDLayerNormZero2Condition(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2. 
Feed Forward self.norm2 = AMDLayerNormZero2Condition(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, condition_states1: torch.Tensor, condition_states2: torch.Tensor, temb: torch.Tensor, ) -> torch.Tensor: """ ************************* ****************** ******************** * hidden_states * *condition_states1* * condition_states2* ************************* ****************** ******************** """ hidden_length = hidden_states.shape[1] condition1_length = condition_states1.shape[1] condition2_length = condition_states2.shape[1] norm_hidden_states, norm_condition_states1,norm_condition_states2, gate_msa, c_gate_msa1,c_gate_msa2 = self.norm1( hidden_states, condition_states1,condition_states2, temb ) # AMD uses concatenated image + motion embeddings with self-attention instead of using # them in cross-attention individually norm_hidden_states = torch.cat([norm_hidden_states, norm_condition_states1,norm_condition_states2], dim=1) attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, ) hidden_states = hidden_states + gate_msa * attn_output[:,:hidden_length] condition_states1 = condition_states1 + c_gate_msa1 * attn_output[:, hidden_length:hidden_length+condition1_length] condition_states2 = condition_states2 + c_gate_msa2 * attn_output[:, hidden_length+condition1_length:] # norm & modulate norm_hidden_states, norm_condition_states1,norm_condition_states2, gate_ff, c_gate_ff1,c_gate_ff2 = self.norm2( hidden_states, condition_states1,condition_states2, temb ) # feed-forward norm_hidden_states = torch.cat([norm_hidden_states, norm_condition_states1,norm_condition_states2], dim=1) ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate_ff * ff_output[:, :hidden_length] condition_states1 = condition_states1 + c_gate_ff1 * ff_output[:, hidden_length:hidden_length+condition1_length] condition_states2 = condition_states2 + c_gate_ff2 * ff_output[:, hidden_length+condition1_length:] return hidden_states, condition_states1,condition_states2 class TransformerBlock2Condition_SimpleAdaLN(nn.Module): r""" TransformerBlock2Condition_SimpleAdaLN """ def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1. Self Attention self.norm1 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.norm1_condition1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) self.norm1_condition2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2. 
Feed Forward self.norm2 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) self.norm2_condition1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) self.norm2_condition2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, condition_states1: torch.Tensor, condition_states2: torch.Tensor, temb: torch.Tensor, ) -> torch.Tensor: """ ************************* ****************** ******************** * hidden_states * *condition_states1* * condition_states2* ************************* ****************** ******************** """ hidden_length = hidden_states.shape[1] condition1_length = condition_states1.shape[1] condition2_length = condition_states2.shape[1] # norm norm_hidden_states,gate = self.norm1(hidden_states, temb=temb) norm_condition_states1 = self.norm1_condition1(condition_states1) norm_condition_states2 = self.norm1_condition2(condition_states2) # AMD uses concatenated image + motion embeddings with self-attention instead of using # them in cross-attention individually norm_hidden_states = torch.cat([norm_hidden_states, norm_condition_states1,norm_condition_states2], dim=1) attn_output = self.attn1( hidden_states=norm_hidden_states, encoder_hidden_states=None, ) hidden_states = hidden_states + gate * attn_output[:,:hidden_length] condition_states1 = condition_states1 + attn_output[:, hidden_length:hidden_length+condition1_length] condition_states2 = condition_states2 + attn_output[:, hidden_length+condition1_length:] # norm & modulate norm_hidden_states,gate = self.norm2(hidden_states, temb=temb) norm_condition_states1 = self.norm2_condition1(condition_states1) norm_condition_states2 = self.norm2_condition2(condition_states2) # feed-forward norm_hidden_states = torch.cat([norm_hidden_states, norm_condition_states1,norm_condition_states2], dim=1) ff_output = self.ff(norm_hidden_states) hidden_states = hidden_states + gate * ff_output[:, :hidden_length] condition_states1 = condition_states1 + ff_output[:, hidden_length:hidden_length+condition1_length] condition_states2 = condition_states2 + ff_output[:, hidden_length+condition1_length:] return hidden_states, condition_states1,condition_states2 class Any2MotionTransformerBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, motion_frames : int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() self.motion_frames = motion_frames # 1.1 norm_in self.norm1 = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=dim*2, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1 ) # 1.2 self-attention self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2.1 norm self.norm2 = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=dim*2, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1 ) # 2.2 cross-attention for refimg self.attn2 = 
Attention( query_dim=dim, cross_attention_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, out_bias=attention_out_bias, ) # 3.1 norm self.norm3 = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=dim*2, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1 ) # 3.2 cross-attention for extra-condition self.attn3 = Attention( query_dim=dim, cross_attention_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, out_bias=attention_out_bias, ) # 4. Feed Forward self.norm4 = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=dim*2, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1 ) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, refimg_states: torch.Tensor, extra_states: torch.Tensor, temb: torch.Tensor, ) -> torch.Tensor: assert hidden_states.dim() == refimg_states.dim() and hidden_states.dim() == extra_states.dim() , f"hidden_states.dim():{hidden_states.dim()},refimg_states.dim():{refimg_states.dim()},extra_states.dim():{extra_states.dim()}" # 1.1 norm hidden_states = self.norm1(hidden_states, temb=temb) # 1.2 3D self-attention hidden_states = einops.rearrange(hidden_states, '(b f) l d -> b (f l) d',f=self.motion_frames) attn_output = self.attn1(hidden_states, None) hidden_states = hidden_states + attn_output hidden_states = einops.rearrange(hidden_states, 'b (f l) d -> (b f) l d',f=self.motion_frames) # 2.1 norm hidden_states = self.norm2(hidden_states, temb=temb) # 2.2 cross-attention for refimg attn_output = self.attn2(hidden_states, refimg_states) # 3.1 norm hidden_states = hidden_states + attn_output hidden_states = self.norm3(hidden_states, temb=temb) # 3.2 cross-attention for extra-condition attn_output = self.attn3(hidden_states, extra_states) # 4.1 norm hidden_states = hidden_states + attn_output hidden_states = self.norm4(hidden_states, temb=temb) # 4.2 ff hidden_states = self.ff(hidden_states) + hidden_states return hidden_states class A2MCrossAttnBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1.1 norm_in self.norm1 = AMDLayerNormZero(conditioning_dim=time_embed_dim, embedding_dim=dim) # 2.2 cross-attention for refimg self.attn = Attention( query_dim=dim, cross_attention_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, out_bias=attention_out_bias, ) # Feed Forward self.norm2 = AMDLayerNormZero(conditioning_dim=time_embed_dim, embedding_dim=dim) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, motion_hidden_states: torch.Tensor, # N,FL,D ref_motion_hidden_states: torch.Tensor, # N,TL,D conditon_hidden_states: torch.Tensor, # N,T+F,W,D temb: torch.Tensor, ) -> torch.Tensor: N,FL,D = motion_hidden_states.shape N,TL,D = ref_motion_hidden_states.shape N,T_F,W,D = conditon_hidden_states.shape L = (FL + TL)//T_F if 
conditon_hidden_states.dim() == 4:
            conditon_hidden_states = einops.rearrange(conditon_hidden_states, 'n f w d -> (n f) w d')  # N(T+F),W,D

        # norm1
        norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            motion_hidden_states, ref_motion_hidden_states, temb
        )

        # transform for cross attn
        hidden_states = torch.cat([norm_ref_motion_hidden_states, norm_motion_hidden_states], dim=1)  # N,TL+FL,D
        hidden_states = einops.rearrange(hidden_states, 'n (f l) d -> (n f) l d', l=L)  # N(T+F),L,D
        assert hidden_states.shape[0] == conditon_hidden_states.shape[0], \
            f'hidden_states.shape {hidden_states.shape}, conditon_hidden_states.shape {conditon_hidden_states.shape}'

        # cross-attention for audio
        attn_output = self.attn(hidden_states, conditon_hidden_states)  # N(T+F),L,D
        attn_output = einops.rearrange(attn_output, '(n f) l d -> n f l d', n=N).flatten(1, 2)  # N,TL+FL,D
        motion_hidden_states = motion_hidden_states + gate_msa * attn_output[:, TL:]  # N,FL,D
        ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * attn_output[:, :TL]  # N,TL,D

        # norm2
        norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm2(
            motion_hidden_states, ref_motion_hidden_states, temb
        )
        hidden_states = torch.cat([norm_ref_motion_hidden_states, norm_motion_hidden_states], dim=1)  # N,TL+FL,D

        # ff
        hidden_states = self.ff(hidden_states)  # N,TL+FL,D
        motion_hidden_states = motion_hidden_states + gate_msa * hidden_states[:, TL:]  # N,FL,D
        ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * hidden_states[:, :TL]  # N,TL,D

        return motion_hidden_states, ref_motion_hidden_states


class A2MMotionSelfAttnBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1.1 norm_in
        self.norm1 = AMDLayerNormZero(conditioning_dim=time_embed_dim, embedding_dim=dim)

        # 1.2 self-attention
        self.attn = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # 2.1 norm
        self.norm2 = AMDLayerNormZero(conditioning_dim=time_embed_dim, embedding_dim=dim)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        motion_hidden_states: torch.Tensor,  # N,FL,D
        ref_motion_hidden_states: torch.Tensor,  # N,TL,D
        temb: torch.Tensor,
    ) -> torch.Tensor:
        N, FL, D = motion_hidden_states.shape
        N, TL, D = ref_motion_hidden_states.shape

        # norm1
        norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            motion_hidden_states, ref_motion_hidden_states, temb
        )
        hidden_states = torch.cat([norm_ref_motion_hidden_states, norm_motion_hidden_states], dim=1)  # N,TL+FL,D

        # self-attention
        attn_output = self.attn(hidden_states, None)
        motion_hidden_states = motion_hidden_states + gate_msa * attn_output[:, TL:]  # N,FL,D
        ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * attn_output[:, :TL]  # N,TL,D

        # norm2
        norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm2(
            motion_hidden_states, ref_motion_hidden_states, temb
        )
        hidden_states =
torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) # N,TL+FL,D # ff hidden_states = self.ff(hidden_states) # N,TL+FL,D motion_hidden_states = motion_hidden_states + gate_msa * hidden_states[:,TL:] # N,FL,D ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * hidden_states[:,:TL] # N,TL,D return motion_hidden_states,ref_motion_hidden_states class A2MMotionSelfAttnBlockDoubleRef(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, last_layer:bool = False, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = False, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() self.last_layer = last_layer # 1.1 norm_in self.norm1 = AMDLayerNormZero(conditioning_dim=time_embed_dim, embedding_dim=dim) # 1.2 self-attention self.attn = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2.1 norm self.norm2 = AMDLayerNormZero(conditioning_dim=time_embed_dim, embedding_dim=dim) self.attn2 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) self.norm3 = AMDLayerNormZero(conditioning_dim=time_embed_dim, embedding_dim=dim) if self.last_layer: self.norm4 = None else: self.norm4 = AMDLayerNormZero_OneVariable(conditioning_dim=time_embed_dim, embedding_dim=dim) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) if not self.last_layer: self.ff2 = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) else: self.ff2 = None def forward( self, motion_hidden_states: torch.Tensor, # N,FL,D ref_motion_hidden_states: torch.Tensor, # N,TL,D randomref_motion_hidden_states: torch.Tensor, # N,SL,D temb: torch.Tensor, motion_token: int = 1, ) -> torch.Tensor: N,FL,D = motion_hidden_states.shape N,TL,D = ref_motion_hidden_states.shape NF,SL,D = randomref_motion_hidden_states.shape L = motion_token T = TL // L F = FL // L assert F == NF // N # randomref_motion_hidden_states = torch.repeat_interleave(F, dim=0) # NF,SL,D # norm1 norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm1( motion_hidden_states, ref_motion_hidden_states, temb ) hidden_states = torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) # N,TL+FL,D # self-attention attn_output = self.attn(hidden_states, None) motion_hidden_states = motion_hidden_states + gate_msa * attn_output[:,TL:] # N,FL,D ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * attn_output[:,:TL] # N,TL,D # norm2 motion_hidden_states = einops.rearrange(motion_hidden_states,"n (f l) d -> (n f) l d",l=L) # NF,L,D flat_temb = temb.repeat_interleave(F,dim=0) # NF,D norm_motion_hidden_states, norm_randomref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm2( motion_hidden_states, randomref_motion_hidden_states, flat_temb ) hidden_states = torch.cat([norm_randomref_motion_hidden_states,norm_motion_hidden_states],dim=1) # NF,SL+L,D # self-attention2 attn_output = 
self.attn2(hidden_states, None) motion_hidden_states = motion_hidden_states + gate_msa * attn_output[:,SL:] # NF,L,D randomref_motion_hidden_states = randomref_motion_hidden_states + enc_gate_msa * attn_output[:,:SL] # NF,SL,D # norm3 motion_hidden_states = einops.rearrange(motion_hidden_states,"(n f) l d -> n (f l) d",n=N) # N,FL,D norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm3( motion_hidden_states, ref_motion_hidden_states, temb ) # ff hidden_states = torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) # N,TL+FL,D hidden_states = self.ff(hidden_states) # N,TL+FL,D motion_hidden_states = motion_hidden_states + gate_msa * hidden_states[:,TL:] # N,FL,D ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * hidden_states[:,:TL] # N,TL,D # ff2 if not self.last_layer: norm_randomref_motion_hidden_states,gate_msa_r = self.norm4(randomref_motion_hidden_states,flat_temb) norm_randomref_motion_hidden_states = self.ff2(norm_randomref_motion_hidden_states) # NF,SL,D randomref_motion_hidden_states = randomref_motion_hidden_states + gate_msa_r * norm_randomref_motion_hidden_states# NF,SL,D return motion_hidden_states,ref_motion_hidden_states,randomref_motion_hidden_states # ----------------------- A2M audio ------------------------ class AudioToImageShapeMlp(nn.Module): def __init__(self, audio_dim:int = 384, audio_block:int = 50, outchannel:int = 256, out_height:int = 4, out_width:int = 4, **kwargs ): super().__init__() self.outchannel = outchannel self.out_height = out_height self.out_width = out_width outdim = outchannel * out_height * out_width self.mlp = Mlp(in_features=audio_dim*audio_block,hidden_features=outdim,out_features=outdim) def forward(self,audio_feature:torch.Tensor): """ Args: audio_feature (torch.Tensor): (N,F,M,C) Returns: audio_feature (torch.Tensor): (N,F,D) """ n,f,m,d = audio_feature.shape audio_feature = einops.rearrange(audio_feature,'n f m d -> n f (m d)') audio_feature = self.mlp(audio_feature) audio_feature = einops.rearrange(audio_feature,'n f (c h w) -> n f c h w',c=self.outchannel,h=self.out_height,w=self.out_width) return audio_feature class AudioFeatureMlp(nn.Module): def __init__(self, audio_dim:int = 384, audio_block:int = 50, hidden_dim:int = 128, outdim:int = 1024, **kwargs ): super().__init__() # self.mlp1 = Mlp(in_features=audio_dim, # hidden_features=hidden_dim, # out_features=hidden_dim) # self.mlp2 = Mlp(in_features=audio_block * hidden_dim, # hidden_features=outdim, # out_features=outdim) self.mlp = Mlp(in_features=audio_dim*audio_block,hidden_features=outdim,out_features=outdim) def forward(self,audio_feature:torch.Tensor): """ Args: audio_feature (torch.Tensor): (N,F,M,C) Returns: audio_feature (torch.Tensor): (N,F,D) """ # n,f,m,d = audio_feature.shape # audio_feature = self.mlp1(audio_feature) # audio_feature = audio_feature.reshape(n,f,-1) # audio_feature = self.mlp2(audio_feature) audio_feature = einops.rearrange(audio_feature,'n f m d -> n f (m d)') audio_feature = self.mlp(audio_feature) return audio_feature class AudioFeatureWindowMlp(nn.Module): def __init__(self, audio_dim:int = 384, audio_block:int = 50, intermediate_dim :int = 1024, window_size:int = 12, outdim:int = 768, **kwargs ): super().__init__() self.window_size = window_size self.ff1 = nn.Linear(audio_dim*audio_block, intermediate_dim) self.ff2 = nn.Linear(intermediate_dim,intermediate_dim) self.ff3 = nn.Linear(intermediate_dim,window_size * outdim) self.norm = nn.LayerNorm(outdim) def 
forward(self,audio_feature:torch.Tensor): """ Args: audio_feature (torch.Tensor): (N,F,M,C) Returns: audio_feature (torch.Tensor): (N,F,W,D) """ n,f,m,d = audio_feature.shape audio_feature = einops.rearrange(audio_feature,'n f m d -> n f (m d)') audio_feature = torch.relu(self.ff1(audio_feature)) # n f inter audio_feature = torch.relu(self.ff2(audio_feature)) # n f inter audio_feature = torch.relu(self.ff3(audio_feature)) # n f w*d audio_feature = einops.rearrange(audio_feature,"n f (w d) -> n f w d",w= self.window_size) audio_feature = self.norm(audio_feature) return audio_feature class RefMotionRefImgeBlock(nn.Module): def __init__( self, dim: int, num_attention_heads: int, attention_head_dim: int, time_embed_dim: int, dropout: float = 0.0, activation_fn: str = "gelu-approximate", attention_bias: bool = True, qk_norm: bool = True, norm_elementwise_affine: bool = True, norm_eps: float = 1e-5, final_dropout: bool = True, ff_inner_dim: Optional[int] = None, ff_bias: bool = True, attention_out_bias: bool = True, ): super().__init__() # 1.1 norm_in self.norm1 = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=dim*2, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1 ) # 1.2 self-attention self.attn1 = Attention( query_dim=dim, dim_head=attention_head_dim, heads=num_attention_heads, qk_norm="layer_norm" if qk_norm else None, eps=1e-6, bias=attention_bias, out_bias=attention_out_bias, ) # 2.1 norm self.norm2 = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=dim*2, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1 ) # 2.2 cross-attention for refmotion self.attn2 = Attention( query_dim=dim, cross_attention_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, out_bias=attention_out_bias, ) # 3.1 norm self.norm3 = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=dim*2, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1 ) # 3.2 cross-attention for refimg self.attn3 = Attention( query_dim=dim, cross_attention_dim=dim, heads=num_attention_heads, dim_head=attention_head_dim, dropout=dropout, bias=attention_bias, out_bias=attention_out_bias, ) # 4. 
Feed Forward self.norm4 = AdaLayerNorm( embedding_dim=time_embed_dim, output_dim=dim*2, norm_elementwise_affine=norm_elementwise_affine, norm_eps=norm_eps, chunk_dim=1 ) self.ff = FeedForward( dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout, inner_dim=ff_inner_dim, bias=ff_bias, ) def forward( self, hidden_states: torch.Tensor, # N,L1,D refmotion_states: torch.Tensor, # N,L2,D refimg_states: torch.Tensor, # N,L3,D temb: torch.Tensor, ) -> torch.Tensor: assert hidden_states.dim() == refimg_states.dim() and hidden_states.dim() == refmotion_states.dim() , f"hidden_states.dim():{hidden_states.dim()},refimg_states.dim():{refimg_states.dim()}" # 1.1 norm hidden_states = self.norm1(hidden_states, temb=temb) # 1.2 3D self-attention attn_output = self.attn1(hidden_states, None) hidden_states = hidden_states + attn_output # 2.1 norm hidden_states = self.norm2(hidden_states, temb=temb) # 2.2 cross-attention for refmotion attn_output = self.attn2(hidden_states, refmotion_states) # 3.1 norm hidden_states = hidden_states + attn_output hidden_states = self.norm3(hidden_states, temb=temb) # 3.2 cross-attention for refimg attn_output = self.attn3(hidden_states, refimg_states) # 4.1 norm hidden_states = hidden_states + attn_output hidden_states = self.norm4(hidden_states, temb=temb) # 4.2 ff hidden_states = self.ff(hidden_states) + hidden_states return hidden_states class Audio2Pose(nn.Module): def __init__(self, audio_dim:int = 384, audio_block:int = 50, motion_height:int = 4, motion_width:int = 4, motion_dim:int = 256, pose_width:int = 32, pose_height:int = 32, pose_dim:int = 4, num_frames:int = 15, **kwargs ): super().__init__() self.num_frames = num_frames self.pw = pose_width self.ph = pose_height self.pc = pose_dim self.audio_encoder = AudioToImageShapeMlp( audio_dim=audio_dim, audio_block = audio_block, outchannel=motion_dim, out_height=motion_height, out_width=motion_width, ) # (NF,256,4,4) self.pose_predictor = Upsampler( in_channel=motion_dim, out_channel=pose_dim, block_out_channels=(motion_dim,128,64,32), ) self.pose_downsample = DownEncoder(in_channel=pose_dim,block_out_channels=(32,64,128,motion_dim)) def forward(self,audio_feature:torch.Tensor,pose_gt:torch.Tensor): """ Args: audio_feature (torch.Tensor): (N,F,M,D) pose_gt (torch.Tensor): (N,F,C,H,W) Returns: pose_pred (torch.Tensor): (N,F,C,H,W), used for loss calculation pose_transform (torch.Tensor): (N,F,256,4,4), used for diffusion audio_hidden_state (torch.Tensor): (N,F,256,4,4), used for audio condition injection """ b,f,m,d = audio_feature.shape audio_hidden_state = self.audio_encoder(audio_feature) # (N,F,256,4,4) audio_hidden_state = einops.rearrange(audio_hidden_state,'n f c h w -> (n f) c h w') pose_pre = self.pose_predictor(audio_hidden_state) pose_gt = einops.rearrange(pose_gt,'n f c h w -> (n f) c h w') pose_gt_transform = self.pose_downsample(pose_gt) pose_pre = einops.rearrange(pose_pre,'(n f) c h w -> n f c h w',n=b) # (8,15,256,4,4) pose_gt_transform = einops.rearrange(pose_gt_transform,'(n f) c h w -> n f c h w',n=b) # (4,15,4,32,32) audio_hidden_state = einops.rearrange(audio_hidden_state,'(n f) c h w -> n f c h w',n=b) # (4,15,256,4,4) return pose_pre, pose_gt_transform, audio_hidden_state def prepare_extra(self,audio:torch.Tensor,pose:torch.Tensor): b = audio.shape[0] audio_hidden_state = self.audio_encoder(audio) pose_pred = self.pose_predictor(audio_hidden_state) pose_pred = self.pose_downsample(pose_pred) pose_pred = einops.rearrange(pose_pred,'(n f) c h w -> n f c h w',n=b) # 
class Audio2Pose(nn.Module):
    def __init__(
        self,
        audio_dim: int = 384,
        audio_block: int = 50,
        motion_height: int = 4,
        motion_width: int = 4,
        motion_dim: int = 256,
        pose_width: int = 32,
        pose_height: int = 32,
        pose_dim: int = 4,
        num_frames: int = 15,
        **kwargs,
    ):
        super().__init__()
        self.num_frames = num_frames
        self.pw = pose_width
        self.ph = pose_height
        self.pc = pose_dim

        self.audio_encoder = AudioToImageShapeMlp(
            audio_dim=audio_dim,
            audio_block=audio_block,
            outchannel=motion_dim,
            out_height=motion_height,
            out_width=motion_width,
        )  # (N, F, motion_dim, 4, 4)
        self.pose_predictor = Upsampler(
            in_channel=motion_dim,
            out_channel=pose_dim,
            block_out_channels=(motion_dim, 128, 64, 32),
        )
        self.pose_downsample = DownEncoder(in_channel=pose_dim, block_out_channels=(32, 64, 128, motion_dim))

    def forward(self, audio_feature: torch.Tensor, pose_gt: torch.Tensor):
        """
        Args:
            audio_feature (torch.Tensor): (N, F, M, D)
            pose_gt (torch.Tensor): (N, F, C, H, W)

        Returns:
            pose_pre (torch.Tensor): (N, F, C, H, W), used for loss calculation
            pose_gt_transform (torch.Tensor): (N, F, motion_dim, 4, 4), used for diffusion
            audio_hidden_state (torch.Tensor): (N, F, motion_dim, 4, 4), used for audio condition injection
        """
        b, f, m, d = audio_feature.shape
        audio_hidden_state = self.audio_encoder(audio_feature)  # (N, F, motion_dim, 4, 4)
        audio_hidden_state = einops.rearrange(audio_hidden_state, 'n f c h w -> (n f) c h w')
        pose_pre = self.pose_predictor(audio_hidden_state)

        pose_gt = einops.rearrange(pose_gt, 'n f c h w -> (n f) c h w')
        pose_gt_transform = self.pose_downsample(pose_gt)

        pose_pre = einops.rearrange(pose_pre, '(n f) c h w -> n f c h w', n=b)                      # (N, F, C, H, W)
        pose_gt_transform = einops.rearrange(pose_gt_transform, '(n f) c h w -> n f c h w', n=b)    # (N, F, motion_dim, 4, 4)
        audio_hidden_state = einops.rearrange(audio_hidden_state, '(n f) c h w -> n f c h w', n=b)  # (N, F, motion_dim, 4, 4)
        return pose_pre, pose_gt_transform, audio_hidden_state

    def prepare_extra(self, audio: torch.Tensor, pose: torch.Tensor):
        b = audio.shape[0]
        audio_hidden_state = self.audio_encoder(audio)  # (N, F, motion_dim, 4, 4)
        # flatten frames into the batch dimension, as in forward()
        audio_hidden_state = einops.rearrange(audio_hidden_state, 'n f c h w -> (n f) c h w')
        pose_pred = self.pose_predictor(audio_hidden_state)
        pose_pred = self.pose_downsample(pose_pred)
        pose_pred = einops.rearrange(pose_pred, '(n f) c h w -> n f c h w', n=b)                    # (N, F, motion_dim, 4, 4)
        audio_hidden_state = einops.rearrange(audio_hidden_state, '(n f) c h w -> n f c h w', n=b)  # (N, F, motion_dim, 4, 4)
        return audio_hidden_state, pose_pred


class MotionTrensferBlock(nn.Module):
    r"""
    MotionTrensferBlock: joint self-attention over the concatenation of motion tokens
    (``hidden_states``) and image tokens (``encoder_hidden_states``).
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self Attention
        self.norm1 = AMDLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # 2. Feed Forward
        self.norm2 = AMDLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,          # (N*F, L1, D) motion tokens
        encoder_hidden_states: torch.Tensor,  # (N*F, L2, D) image tokens
        temb: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )

        # attention
        motion_length = norm_hidden_states.shape[1]
        # AMD uses concatenated image + motion embeddings with self-attention instead of using
        # them in cross-attention individually
        norm_hidden_states = torch.cat([norm_hidden_states, norm_encoder_hidden_states], dim=1)
        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
        )
        hidden_states = hidden_states + gate_msa * attn_output[:, :motion_length]
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, motion_length:]

        # norm & modulate
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )

        # feed-forward (keep the same motion-first ordering as the attention branch so the
        # split below assigns each token group back to the correct stream)
        norm_hidden_states = torch.cat([norm_hidden_states, norm_encoder_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + gate_ff * ff_output[:, :motion_length]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, motion_length:]

        return hidden_states, encoder_hidden_states
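
# Hedged sketch of the token bookkeeping used by MotionTrensferBlock (not part of the
# original model code): motion and image tokens are concatenated along the sequence
# axis, attended to jointly, then split back by the motion token count. The sizes
# below are hypothetical; identity stands in for the joint attention output.
def _example_joint_token_split():
    motion_tokens = torch.randn(2, 16, 256)  # (N*F, L1, D)
    image_tokens = torch.randn(2, 64, 256)   # (N*F, L2, D)
    motion_length = motion_tokens.shape[1]
    joint = torch.cat([motion_tokens, image_tokens], dim=1)  # (N*F, L1 + L2, D)
    joint_out = joint  # placeholder for the joint self-attention output
    motion_out = joint_out[:, :motion_length]
    image_out = joint_out[:, motion_length:]
    assert motion_out.shape == motion_tokens.shape and image_out.shape == image_tokens.shape
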
# ----------------------- A2P ------------------------
class A2PTemporalSpatialBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: Optional[int] = None,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # Temporal Attention
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # Spatial Attention
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.attn2 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # Feed Forward
        self.norm3 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        N, F, L, D = hidden_states.shape

        # norm1
        hidden_states = einops.rearrange(hidden_states, 'n f l d -> (n l) f d')  # (N*L, F, D)
        norm_hidden_states = self.norm1(hidden_states)                           # (N*L, F, D)
        # temporal attention
        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
        )
        hidden_states = hidden_states + attn_output

        # norm2
        hidden_states = einops.rearrange(hidden_states, '(n l) f d -> (n f) l d', n=N, l=L)
        norm_hidden_states = self.norm2(hidden_states)  # (N*F, L, D)
        # spatial attention
        attn_output = self.attn2(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
        )
        hidden_states = hidden_states + attn_output

        # norm & modulate
        norm_hidden_states = self.norm3(hidden_states)  # (N*F, L, D)
        # feed-forward
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output

        # transform back
        hidden_states = einops.rearrange(hidden_states, '(n f) l d -> n f l d', n=N)
        return hidden_states  # (N, F, L, D)


class A2PCrossAudioBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: Optional[int] = None,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # Cross Attention over audio tokens
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # Feed Forward
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        N, F, L, D = hidden_states.shape
        _, _, W, _ = audio_hidden_states.shape

        # norm1
        hidden_states = einops.rearrange(hidden_states, 'n f l d -> (n f) l d')              # (N*F, L, D)
        norm_hidden_states = self.norm1(hidden_states)                                       # (N*F, L, D)
        audio_hidden_states = einops.rearrange(audio_hidden_states, 'n f w d -> (n f) w d')  # (N*F, W, D)
        # cross-attention with the audio tokens
        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=audio_hidden_states,
        )
        hidden_states = hidden_states + attn_output  # (N*F, L, D)

        # norm & modulate
        norm_hidden_states = self.norm2(hidden_states)  # (N*F, L, D)
        # feed-forward
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output

        # transform back
        hidden_states = einops.rearrange(hidden_states, '(n f) l d -> n f l d', n=N, f=F)
        return hidden_states  # (N, F, L, D)
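
# Hedged usage sketch for the A2P blocks above (not part of the original model code;
# all sizes are hypothetical). It chains one temporal/spatial block with one audio
# cross-attention block on (N, F, L, D) pose tokens and (N, F, W, D) audio tokens.
def _example_a2p_blocks():
    spatial_temporal = A2PTemporalSpatialBlock(dim=256, num_attention_heads=8, attention_head_dim=32)
    cross_audio = A2PCrossAudioBlock(dim=256, num_attention_heads=8, attention_head_dim=32)
    pose_tokens = torch.randn(2, 15, 16, 256)   # (N, F, L, D)
    audio_tokens = torch.randn(2, 15, 8, 256)   # (N, F, W, D)
    hidden = spatial_temporal(pose_tokens)
    hidden = cross_audio(hidden, audio_tokens)
    assert hidden.shape == pose_tokens.shape
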