import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import einops
import torch
from einops import rearrange
from timm.models.layers import Mlp
from torch import nn
from torch.nn import functional as F

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import Attention, FeedForward
from diffusers.models.embeddings import (
    TimestepEmbedding,
    Timesteps,
    get_1d_sincos_pos_embed_from_grid,
    get_2d_sincos_pos_embed,
    get_3d_sincos_pos_embed,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.resnet import Downsample2D, ResnetBlock2D
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
from diffusers.models.upsampling import Upsample2D
from diffusers.utils import BaseOutput, deprecate, is_torch_version, logging
|
|
|
|
|
# Module-level switches for attention-map visualization: when VIS_ATTEN_FLAG is
# enabled, AMDTransformerBlock appends per-layer attention scores to
# `attention_maps` during the forward pass.
VIS_ATTEN_FLAG = False
attention_maps = []


def get_attention_maps():
    global attention_maps
    return attention_maps


def clear_attention_maps():
    global attention_maps
    attention_maps.clear()


def set_vis_atten_flag(flag):
    global VIS_ATTEN_FLAG
    VIS_ATTEN_FLAG = flag
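
# Usage sketch (illustrative, not part of the module API): capturing attention
# maps for one forward pass. Assumes `model` is a module built from the blocks
# below.
#
#   set_vis_atten_flag(True)
#   _ = model(sample)                      # AMDTransformerBlock layers append scores
#   maps = get_attention_maps()            # list of (batch, heads, L, S) CPU tensors
#   clear_attention_maps()                 # reset before the next probe
#   set_vis_atten_flag(False)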
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class DownEncoderBlock2D(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
dropout: float = 0.0, |
|
|
num_layers: int = 1, |
|
|
resnet_eps: float = 1e-6, |
|
|
resnet_time_scale_shift: str = "default", |
|
|
resnet_act_fn: str = "swish", |
|
|
resnet_groups: int = 32, |
|
|
resnet_pre_norm: bool = True, |
|
|
output_scale_factor: float = 1.0, |
|
|
add_downsample: bool = True, |
|
|
downsample_padding: int = 1, |
|
|
): |
|
|
super().__init__() |
|
|
resnets = [] |
|
|
|
|
|
for i in range(num_layers): |
|
|
in_channels = in_channels if i == 0 else out_channels |
|
|
resnets.append( |
|
|
ResnetBlock2D( |
|
|
in_channels=in_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=None, |
|
|
eps=resnet_eps, |
|
|
groups=resnet_groups, |
|
|
dropout=dropout, |
|
|
time_embedding_norm=resnet_time_scale_shift, |
|
|
non_linearity=resnet_act_fn, |
|
|
output_scale_factor=output_scale_factor, |
|
|
pre_norm=resnet_pre_norm, |
|
|
) |
|
|
) |
|
|
|
|
|
self.resnets = nn.ModuleList(resnets) |
|
|
|
|
|
if add_downsample: |
|
|
self.downsamplers = nn.ModuleList( |
|
|
[ |
|
|
Downsample2D( |
|
|
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" |
|
|
) |
|
|
] |
|
|
) |
|
|
else: |
|
|
self.downsamplers = None |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: |
|
|
if len(args) > 0 or kwargs.get("scale", None) is not None: |
|
|
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." |
|
|
deprecate("scale", "1.0.0", deprecation_message) |
|
|
|
|
|
for resnet in self.resnets: |
|
|
hidden_states = resnet(hidden_states, temb=None) |
|
|
|
|
|
if self.downsamplers is not None: |
|
|
for downsampler in self.downsamplers: |
|
|
hidden_states = downsampler(hidden_states) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
class MidBlock2D(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
in_channel: int = 64, |
|
|
out_channel: int = 1280, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.mid_convs = nn.ModuleList() |
|
|
self.mid_convs.append(nn.Sequential( |
|
|
nn.Conv2d( |
|
|
in_channels=in_channel, |
|
|
out_channels=in_channel, |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1 |
|
|
), |
|
|
nn.ReLU(), |
|
|
nn.Conv2d( |
|
|
in_channels=in_channel, |
|
|
out_channels=in_channel, |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1 |
|
|
), |
|
|
)) |
|
|
self.mid_convs.append(nn.Conv2d( |
|
|
in_channels=in_channel, |
|
|
out_channels=out_channel, |
|
|
kernel_size=1, |
|
|
stride=1, |
|
|
)) |
|
|
|
|
|
    def forward(self, x):
        # Feed each stage's output into the next: the 3x3 conv stack first,
        # then the 1x1 channel projection.
        for mid_conv in self.mid_convs:
            x = mid_conv(x)

        return x
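
# Shape sketch (illustrative): MidBlock2D preserves spatial size; with the defaults,
# (b, 64, h, w) passes through the 3x3 conv stack at 64 channels and is then
# projected by the 1x1 conv to (b, 1280, h, w).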
|
|
|
|
|
class UpDecoderBlock2D(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
in_channels: int, |
|
|
out_channels: int, |
|
|
resolution_idx: Optional[int] = None, |
|
|
dropout: float = 0.0, |
|
|
num_layers: int = 1, |
|
|
resnet_eps: float = 1e-6, |
|
|
resnet_time_scale_shift: str = "default", |
|
|
resnet_act_fn: str = "swish", |
|
|
resnet_groups: int = 32, |
|
|
resnet_pre_norm: bool = True, |
|
|
output_scale_factor: float = 1.0, |
|
|
add_upsample: bool = True, |
|
|
temb_channels: Optional[int] = None, |
|
|
): |
|
|
super().__init__() |
|
|
resnets = [] |
|
|
|
|
|
for i in range(num_layers): |
|
|
input_channels = in_channels if i == 0 else out_channels |
|
|
resnets.append( |
|
|
ResnetBlock2D( |
|
|
in_channels=input_channels, |
|
|
out_channels=out_channels, |
|
|
temb_channels=temb_channels, |
|
|
eps=resnet_eps, |
|
|
groups=resnet_groups, |
|
|
dropout=dropout, |
|
|
time_embedding_norm=resnet_time_scale_shift, |
|
|
non_linearity=resnet_act_fn, |
|
|
output_scale_factor=output_scale_factor, |
|
|
pre_norm=resnet_pre_norm, |
|
|
) |
|
|
) |
|
|
|
|
|
self.resnets = nn.ModuleList(resnets) |
|
|
|
|
|
if add_upsample: |
|
|
self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) |
|
|
else: |
|
|
self.upsamplers = None |
|
|
|
|
|
self.resolution_idx = resolution_idx |
|
|
|
|
|
def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor: |
|
|
for resnet in self.resnets: |
|
|
hidden_states = resnet(hidden_states, temb=temb) |
|
|
|
|
|
if self.upsamplers is not None: |
|
|
for upsampler in self.upsamplers: |
|
|
hidden_states = upsampler(hidden_states) |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
class DuoFrameDownEncoder(ModelMixin, ConfigMixin): |
|
|
|
|
|
_supports_gradient_checkpointing = True |
|
|
|
|
|
@register_to_config |
|
|
def __init__( |
|
|
self, |
|
|
in_channel: int = 4, |
|
|
block_out_channels : Tuple[int] = (64, 128, 256, 256), |
|
|
norm_groups : int = 4, |
|
|
resnet_layers_per_block: int = 2, |
|
|
add_attention : bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.conv_in = nn.Conv2d( |
|
|
in_channel, |
|
|
block_out_channels[0], |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1, |
|
|
) |
|
|
|
|
|
|
|
|
self.downblock = nn.ModuleList() |
|
|
|
|
|
output_channel = block_out_channels[0] |
|
|
        for i, channels in enumerate(block_out_channels):
            input_channel = output_channel
            output_channel = channels
            is_final_block = i == len(block_out_channels) - 1
|
|
|
|
|
self.downblock.append( |
|
|
DownEncoderBlock2D( |
|
|
in_channels=input_channel, |
|
|
out_channels=output_channel, |
|
|
num_layers= resnet_layers_per_block, |
|
|
resnet_groups = norm_groups, |
|
|
add_downsample=not is_final_block, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
self.mid_block = UNetMidBlock2D( |
|
|
in_channels=block_out_channels[-1], |
|
|
resnet_eps=1e-6, |
|
|
output_scale_factor=1, |
|
|
resnet_time_scale_shift="default", |
|
|
attention_head_dim=block_out_channels[-1], |
|
|
resnet_groups=norm_groups, |
|
|
temb_channels=None, |
|
|
add_attention=add_attention, |
|
|
) |
|
|
|
|
|
|
|
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_groups, eps=1e-6) |
|
|
self.conv_act = nn.SiLU() |
|
|
self.conv_out = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], 3, padding=1) |
|
|
|
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
|
if hasattr(module, "gradient_checkpointing"): |
|
|
module.gradient_checkpointing = value |
|
|
|
|
|
def forward(self, x: torch.FloatTensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
* x : (b,c,h,w) |
|
|
Output: |
|
|
* x : (b,c',h/8,w/8) |
|
|
""" |
|
|
|
|
|
|
|
|
x = self.conv_in(x) |
|
|
|
|
|
|
|
|
for downblock in self.downblock: |
|
|
x = downblock(x) |
|
|
|
|
|
|
|
|
x = self.mid_block(x) |
|
|
|
|
|
|
|
|
x = self.conv_norm_out(x) |
|
|
x = self.conv_act(x) |
|
|
x = self.conv_out(x) |
|
|
|
|
|
return x |
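
# Usage sketch (illustrative), assuming the default config:
#
#   enc = DuoFrameDownEncoder()            # block_out_channels=(64, 128, 256, 256)
#   z = torch.randn(2, 4, 64, 64)
#   feat = enc(z)                          # -> (2, 256, 8, 8)
#
# Only the first three stages downsample, so H and W shrink by 8x overall.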
|
|
|
|
|
class MotionDownEncoder(ModelMixin, ConfigMixin): |
|
|
|
|
|
_supports_gradient_checkpointing = True |
|
|
|
|
|
@register_to_config |
|
|
def __init__( |
|
|
self, |
|
|
in_channel: int = 4, |
|
|
block_out_channels : Tuple[int] = (64, 128, 256, 256), |
|
|
norm_groups : int = 32, |
|
|
resnet_layers_per_block: int = 2, |
|
|
add_attention : bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.conv_in = nn.Conv2d( |
|
|
in_channel, |
|
|
block_out_channels[0], |
|
|
kernel_size=1, |
|
|
stride=1, |
|
|
) |
|
|
|
|
|
|
|
|
self.downblock = nn.ModuleList() |
|
|
|
|
|
output_channel = block_out_channels[0] |
|
|
        for i, channels in enumerate(block_out_channels):
            input_channel = output_channel
            output_channel = channels
            is_final_block = i == len(block_out_channels) - 1
|
|
|
|
|
self.downblock.append( |
|
|
DownEncoderBlock2D( |
|
|
in_channels=input_channel, |
|
|
out_channels=output_channel, |
|
|
num_layers= resnet_layers_per_block, |
|
|
resnet_groups = norm_groups, |
|
|
add_downsample=not is_final_block, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
self.mid_block = UNetMidBlock2D( |
|
|
in_channels=block_out_channels[-1], |
|
|
resnet_eps=1e-6, |
|
|
output_scale_factor=1, |
|
|
resnet_time_scale_shift="default", |
|
|
attention_head_dim=block_out_channels[-1], |
|
|
resnet_groups=norm_groups, |
|
|
temb_channels=None, |
|
|
add_attention=add_attention, |
|
|
) |
|
|
|
|
|
|
|
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_groups, eps=1e-6) |
|
|
self.conv_act = nn.SiLU() |
|
|
self.conv_out = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], 3, padding=1) |
|
|
|
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
|
if hasattr(module, "gradient_checkpointing"): |
|
|
module.gradient_checkpointing = value |
|
|
|
|
|
def forward(self, x: torch.FloatTensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
* x : (b,c,h,w) |
|
|
Output: |
|
|
* x : (b,c',h/8,w/8) |
|
|
""" |
|
|
|
|
|
|
|
|
x = self.conv_in(x) |
|
|
|
|
|
|
|
|
for downblock in self.downblock: |
|
|
x = downblock(x) |
|
|
|
|
|
|
|
|
x = self.mid_block(x) |
|
|
|
|
|
|
|
|
x = self.conv_norm_out(x) |
|
|
x = self.conv_act(x) |
|
|
x = self.conv_out(x) |
|
|
|
|
|
return x |
|
|
|
|
|
class DownEncoder(ModelMixin, ConfigMixin): |
|
|
|
|
|
_supports_gradient_checkpointing = True |
|
|
|
|
|
@register_to_config |
|
|
def __init__( |
|
|
self, |
|
|
in_channel: int = 4, |
|
|
block_out_channels : Tuple[int] = (64, 128, 256, 256), |
|
|
norm_groups : int = 8, |
|
|
resnet_layers_per_block: int = 2, |
|
|
add_attention : bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.conv_in = nn.Conv2d( |
|
|
in_channel, |
|
|
block_out_channels[0], |
|
|
kernel_size=1, |
|
|
stride=1, |
|
|
) |
|
|
|
|
|
|
|
|
self.downblock = nn.ModuleList() |
|
|
|
|
|
output_channel = block_out_channels[0] |
|
|
        for i, channels in enumerate(block_out_channels):
            input_channel = output_channel
            output_channel = channels
            is_final_block = i == len(block_out_channels) - 1
|
|
|
|
|
self.downblock.append( |
|
|
DownEncoderBlock2D( |
|
|
in_channels=input_channel, |
|
|
out_channels=output_channel, |
|
|
num_layers= resnet_layers_per_block, |
|
|
resnet_groups = norm_groups, |
|
|
add_downsample=not is_final_block, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
self.mid_block = UNetMidBlock2D( |
|
|
in_channels=block_out_channels[-1], |
|
|
resnet_eps=1e-6, |
|
|
output_scale_factor=1, |
|
|
resnet_time_scale_shift="default", |
|
|
attention_head_dim=block_out_channels[-1], |
|
|
resnet_groups=norm_groups, |
|
|
temb_channels=None, |
|
|
add_attention=add_attention, |
|
|
) |
|
|
|
|
|
|
|
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_groups, eps=1e-6) |
|
|
self.conv_act = nn.SiLU() |
|
|
self.conv_out = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], 3, padding=1) |
|
|
|
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
|
if hasattr(module, "gradient_checkpointing"): |
|
|
module.gradient_checkpointing = value |
|
|
|
|
|
def forward(self, x: torch.FloatTensor) -> torch.Tensor: |
|
|
""" |
|
|
Args: |
|
|
* x : (b,c,h,w) |
|
|
Output: |
|
|
* x : (b,c',h/8,w/8) |
|
|
""" |
|
|
|
|
|
|
|
|
x = self.conv_in(x) |
|
|
|
|
|
|
|
|
for downblock in self.downblock: |
|
|
x = downblock(x) |
|
|
|
|
|
|
|
|
x = self.mid_block(x) |
|
|
|
|
|
|
|
|
x = self.conv_norm_out(x) |
|
|
x = self.conv_act(x) |
|
|
x = self.conv_out(x) |
|
|
|
|
|
return x |
|
|
|
|
|
class Upsampler(ModelMixin, ConfigMixin): |
|
|
|
|
|
_supports_gradient_checkpointing = True |
|
|
|
|
|
@register_to_config |
|
|
def __init__( |
|
|
self, |
|
|
in_channel: int = 256, |
|
|
out_channel: Optional[int] = None, |
|
|
block_out_channels : Tuple[int] = (256, 256, 128, 64), |
|
|
norm_groups : int = 8, |
|
|
resnet_layers_per_block: int = 2, |
|
|
add_attention : bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.out_channel = out_channel |
|
|
|
|
|
|
|
|
self.conv_in = nn.Conv2d( |
|
|
in_channel, |
|
|
block_out_channels[0], |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1, |
|
|
) |
|
|
|
|
|
|
|
|
self.mid_block = UNetMidBlock2D( |
|
|
in_channels=block_out_channels[0], |
|
|
resnet_eps=1e-6, |
|
|
output_scale_factor=1, |
|
|
resnet_time_scale_shift="default", |
|
|
attention_head_dim=block_out_channels[0], |
|
|
resnet_groups=norm_groups, |
|
|
temb_channels=None, |
|
|
add_attention=add_attention, |
|
|
) |
|
|
|
|
|
|
|
|
self.upblock = nn.ModuleList() |
|
|
output_channel = block_out_channels[0] |
|
|
        for i, channels in enumerate(block_out_channels):
            input_channel = output_channel
            output_channel = channels
            is_final_block = i == len(block_out_channels) - 1
|
|
|
|
|
self.upblock.append( |
|
|
UpDecoderBlock2D( |
|
|
in_channels=input_channel, |
|
|
out_channels=output_channel, |
|
|
num_layers= resnet_layers_per_block, |
|
|
resnet_groups = norm_groups, |
|
|
add_upsample=not is_final_block, |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_groups, eps=1e-6) |
|
|
self.conv_act = nn.SiLU() |
|
|
self.conv_out = nn.Conv2d(block_out_channels[-1], block_out_channels[-1], 3, padding=1) |
|
|
|
|
|
|
|
|
if self.out_channel: |
|
|
self.conv_final = nn.Conv2d( |
|
|
block_out_channels[-1], |
|
|
out_channel, |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1, |
|
|
) |
|
|
|
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
|
if hasattr(module, "gradient_checkpointing"): |
|
|
module.gradient_checkpointing = value |
|
|
|
|
|
    def forward(self, x: torch.FloatTensor) -> torch.Tensor:
|
|
""" |
|
|
Args: |
|
|
* x : (b,c,h,w) |
|
|
Output: |
|
|
* x : (b,c',h*8,w*8) |
|
|
""" |
|
|
|
|
|
|
|
|
x = self.conv_in(x) |
|
|
|
|
|
|
|
|
x = self.mid_block(x) |
|
|
|
|
|
|
|
|
for upblock in self.upblock: |
|
|
x = upblock(x) |
|
|
|
|
|
|
|
|
x = self.conv_norm_out(x) |
|
|
x = self.conv_act(x) |
|
|
x = self.conv_out(x) |
|
|
|
|
|
|
|
|
if self.out_channel: |
|
|
x = self.conv_final(x) |
|
|
|
|
|
return x |
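
# Usage sketch (illustrative), assuming the default config:
#
#   up = Upsampler()                       # block_out_channels=(256, 256, 128, 64)
#   feat = torch.randn(2, 256, 8, 8)
#   img = up(feat)                         # -> (2, 64, 64, 64)
#
# Three of the four decoder stages upsample (the final one does not), so H and W
# grow by 8x; pass `out_channel` (e.g. 4) to add a final channel projection.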
|
|
|
|
|
|
|
|
class MapConv(nn.Module): |
|
|
    def __init__(self,
                 in_channel: int = 8,
                 hidden: int = 640,
                 out_channel: int = 4,
                 block_layer: int = 8,
                 groups: int = 2):
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.conv_in = nn.Conv2d( |
|
|
in_channel, |
|
|
hidden, |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1, |
|
|
) |
|
|
|
|
|
|
|
|
self.mid_block = UNetMidBlock2D( |
|
|
in_channels=hidden, |
|
|
resnet_eps=1e-6, |
|
|
output_scale_factor=1, |
|
|
resnet_time_scale_shift="default", |
|
|
attention_head_dim=64, |
|
|
            resnet_groups=groups,
|
|
temb_channels=None, |
|
|
add_attention=True, |
|
|
) |
|
|
|
|
|
|
|
|
self.map = nn.ModuleList() |
|
|
for i in range(block_layer): |
|
|
resnet = ResnetBlock2D( |
|
|
in_channels=hidden, |
|
|
out_channels=hidden, |
|
|
temb_channels=None, |
|
|
                groups=groups,
|
|
) |
|
|
self.map.append(resnet) |
|
|
|
|
|
|
|
|
self.conv_out = nn.Conv2d( |
|
|
hidden, |
|
|
out_channel, |
|
|
kernel_size=3, |
|
|
stride=1, |
|
|
padding=1, |
|
|
) |
|
|
|
|
|
    def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
|
|
|
|
x = self.conv_in(x) |
|
|
x = self.mid_block(x) |
|
|
        for resnet in self.map:
            x = resnet(x, None)
|
|
x = self.conv_out(x) |
|
|
|
|
|
return x |
|
|
def simple_attention_processor( |
|
|
attn: Attention, |
|
|
hidden_states: torch.Tensor, |
|
|
encoder_hidden_states: Optional[torch.Tensor] = None, |
|
|
attention_mask: Optional[torch.Tensor] = None, |
|
|
): |
|
|
residual = hidden_states |
|
|
if attn.spatial_norm is not None: |
|
|
hidden_states = attn.spatial_norm(hidden_states, None) |
|
|
input_ndim = hidden_states.ndim |
|
|
if input_ndim == 4: |
|
|
batch_size, channel, height, width = hidden_states.shape |
|
|
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) |
|
|
batch_size, sequence_length, _ = ( |
|
|
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape |
|
|
) |
|
|
    if attention_mask is not None:
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        # (batch, heads, query_len, key_len); only reshape when a mask was given.
        attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
|
if attn.group_norm is not None: |
|
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) |
|
|
query = attn.to_q(hidden_states) |
|
|
if encoder_hidden_states is None: |
|
|
encoder_hidden_states = hidden_states |
|
|
elif attn.norm_cross: |
|
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) |
|
|
key = attn.to_k(encoder_hidden_states) |
|
|
value = attn.to_v(encoder_hidden_states) |
|
|
inner_dim = key.shape[-1] |
|
|
head_dim = inner_dim // attn.heads |
|
|
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) |
|
|
if attn.norm_q is not None: |
|
|
query = attn.norm_q(query) |
|
|
if attn.norm_k is not None: |
|
|
key = attn.norm_k(key) |
|
|
|
|
|
|
|
|
    score = torch.einsum("bhld,bhsd->bhls", query, key) / math.sqrt(head_dim)
    if attention_mask is not None:
        score = score + attention_mask
    return score.softmax(dim=-1)
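
# Note (illustrative): unlike a regular attention processor, this helper does not
# return projected hidden states; it recomputes and returns the softmax attention
# probabilities with shape (batch, heads, query_len, key_len), which is what the
# VIS_ATTEN_FLAG path in AMDTransformerBlock stores for visualization.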
|
|
|
|
|
class BasicTransformerBlock(nn.Module): |
|
|
r""" |
|
|
Parameters: |
|
|
dim (`int`): The number of channels in the input and output. |
|
|
num_attention_heads (`int`): The number of heads to use for multi-head attention. |
|
|
attention_head_dim (`int`): The number of channels in each head. |
|
|
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. |
|
|
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. |
|
|
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether the attention layers should contain a bias parameter.
|
|
qk_norm (`bool`, defaults to `True`): |
|
|
Whether or not to use normalization after query and key projections in Attention. |
|
|
norm_elementwise_affine (`bool`, defaults to `True`): |
|
|
Whether to use learnable elementwise affine parameters for normalization. |
|
|
norm_eps (`float`, defaults to `1e-5`): |
|
|
Epsilon value for normalization layers. |
|
|
final_dropout (`bool` defaults to `False`): |
|
|
Whether to apply a final dropout after the last feed-forward layer. |
|
|
ff_inner_dim (`int`, *optional*, defaults to `None`): |
|
|
Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used. |
|
|
ff_bias (`bool`, defaults to `True`): |
|
|
Whether or not to use bias in Feed-forward layer. |
|
|
attention_out_bias (`bool`, defaults to `True`): |
|
|
Whether or not to use bias in Attention output projection layer. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: Optional[int] = None, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.norm1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) |
|
|
|
|
|
self.attn1 = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
norm_hidden_states = self.norm1(hidden_states) |
|
|
|
|
|
|
|
|
attn_output = self.attn1( |
|
|
hidden_states=norm_hidden_states, |
|
|
encoder_hidden_states=None, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hidden_states = hidden_states + attn_output |
|
|
|
|
|
|
|
|
norm_hidden_states = self.norm2(hidden_states) |
|
|
|
|
|
|
|
|
ff_output = self.ff(norm_hidden_states) |
|
|
hidden_states = hidden_states + ff_output |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
patch_size: int = 2, |
|
|
in_channels: int = 16, |
|
|
embed_dim: int = 1920, |
|
|
bias: bool = True, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
self.patch_size = patch_size |
|
|
|
|
|
self.proj = nn.Conv2d( |
|
|
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias |
|
|
) |
|
|
|
|
|
def forward(self, image_embeds: torch.Tensor): |
|
|
r""" |
|
|
Args: |
|
|
image_embeds (`torch.Tensor`): |
|
|
Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width) or (batch_size, channels, height, width) |
|
|
        Returns:
            embeds (`torch.Tensor`):
                (batch_size, num_frames * (height // patch_size) * (width // patch_size), embed_dim); num_frames is 1 for 4-D input.
|
|
""" |
|
|
if image_embeds.dim() == 5: |
|
|
batch, num_frames, channels, height, width = image_embeds.shape |
|
|
image_embeds = image_embeds.reshape(-1, channels, height, width) |
|
|
else: |
|
|
batch, channels, height, width = image_embeds.shape |
|
|
num_frames = 1 |
|
|
|
|
|
image_embeds = self.proj(image_embeds) |
|
|
image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:]) |
|
|
image_embeds = image_embeds.flatten(3).transpose(2, 3) |
|
|
image_embeds = image_embeds.flatten(1, 2) |
|
|
|
|
|
return image_embeds |
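
# Shape sketch (illustrative), assuming the defaults:
#
#   pe = PatchEmbed(patch_size=2, in_channels=16, embed_dim=1920)
#   x = torch.randn(2, 4, 16, 32, 32)      # (batch, frames, C, H, W)
#   tokens = pe(x)                         # -> (2, 4 * 16 * 16, 1920)
#
# i.e. each frame contributes (H / patch_size) * (W / patch_size) tokens, and the
# frame and spatial axes are flattened into one token sequence.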
|
|
|
|
|
class AMDLayerNormZero(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
conditioning_dim: int, |
|
|
embedding_dim: int, |
|
|
elementwise_affine: bool = True, |
|
|
eps: float = 1e-5, |
|
|
bias: bool = True, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
|
|
|
self.embed_dim = embedding_dim |
|
|
self.silu = nn.SiLU() |
|
|
self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias) |
|
|
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor |
|
|
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
|
|
|
shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1) |
|
|
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] |
|
|
encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :] |
|
|
return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :] |
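
# Modulation sketch (illustrative): this is the adaLN-Zero pattern. For each stream,
#   out = LayerNorm(x) * (1 + scale) + shift
# with (shift, scale, gate) regressed from the conditioning embedding `temb`; the
# returned gates are applied by the caller to that stream's residual branch.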
|
|
|
|
|
class AMDLayerNormZero_OneVariable(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
conditioning_dim: int, |
|
|
embedding_dim: int, |
|
|
elementwise_affine: bool = True, |
|
|
eps: float = 1e-5, |
|
|
bias: bool = True, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
|
|
|
self.embed_dim = embedding_dim |
|
|
self.silu = nn.SiLU() |
|
|
self.linear = nn.Linear(conditioning_dim, 3 * embedding_dim, bias=bias) |
|
|
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, temb: torch.Tensor |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
|
|
|
shift, scale, gate = self.linear(self.silu(temb)).chunk(3, dim=1) |
|
|
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] |
|
|
return hidden_states, gate[:, None, :] |
|
|
|
|
|
class AMDLayerNormZero2Condition(nn.Module): |
|
|
def __init__( |
|
|
self, |
|
|
conditioning_dim: int, |
|
|
embedding_dim: int, |
|
|
elementwise_affine: bool = True, |
|
|
eps: float = 1e-5, |
|
|
bias: bool = True, |
|
|
) -> None: |
|
|
super().__init__() |
|
|
|
|
|
self.embed_dim = embedding_dim |
|
|
self.silu = nn.SiLU() |
|
|
self.linear = nn.Linear(conditioning_dim, 9 * embedding_dim, bias=bias) |
|
|
self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine) |
|
|
|
|
|
def forward( |
|
|
self, hidden_states: torch.Tensor, condition_states1: torch.Tensor,condition_states2:torch.Tensor, temb: torch.Tensor |
|
|
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
|
|
|
shift, scale, gate, c1_shift, c1_scale, c1_gate,c2_shift, c2_scale, c2_gate = self.linear(self.silu(temb)).chunk(9, dim=1) |
|
|
|
|
|
hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :] |
|
|
condition_states1 = self.norm(condition_states1) * (1 + c1_scale)[:, None, :] + c1_shift[:, None, :] |
|
|
condition_states2 = self.norm(condition_states2) * (1 + c2_scale)[:, None, :] + c2_shift[:, None, :] |
|
|
|
|
|
return hidden_states, condition_states1,condition_states2, gate[:, None, :], c1_gate[:, None, :],c2_gate[:, None, :] |
|
|
|
|
|
class AdaLayerNorm(nn.Module): |
|
|
r""" |
|
|
Norm layer modified to incorporate timestep embeddings. |
|
|
|
|
|
Parameters: |
|
|
embedding_dim (`int`): The size of each embedding vector. |
|
|
num_embeddings (`int`, *optional*): The size of the embeddings dictionary. |
|
|
        output_dim (`int`, *optional*): The output dimension of the modulation projection; defaults to `2 * embedding_dim`.
        norm_elementwise_affine (`bool`, defaults to `False`):
            Whether to use learnable elementwise affine parameters for normalization.
        norm_eps (`float`, defaults to `1e-5`):
            Epsilon value for the normalization layer.
        chunk_dim (`int`, defaults to `0`):
            The dimension along which `temb` is chunked into scale and shift.
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
embedding_dim: int, |
|
|
num_embeddings: Optional[int] = None, |
|
|
output_dim: Optional[int] = None, |
|
|
norm_elementwise_affine: bool = False, |
|
|
norm_eps: float = 1e-5, |
|
|
chunk_dim: int = 0, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.chunk_dim = chunk_dim |
|
|
output_dim = output_dim or embedding_dim * 2 |
|
|
|
|
|
if num_embeddings is not None: |
|
|
self.emb = nn.Embedding(num_embeddings, embedding_dim) |
|
|
else: |
|
|
self.emb = None |
|
|
|
|
|
self.silu = nn.SiLU() |
|
|
self.linear = nn.Linear(embedding_dim, output_dim) |
|
|
self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine) |
|
|
|
|
|
def forward( |
|
|
self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None |
|
|
) -> torch.Tensor: |
|
|
|
|
|
if self.emb is not None: |
|
|
temb = self.emb(timestep) |
|
|
|
|
|
temb = self.linear(self.silu(temb)) |
|
|
|
|
|
if self.chunk_dim == 1: |
|
|
|
|
|
|
|
|
shift, scale = temb.chunk(2, dim=1) |
|
|
shift = shift[:, None, :] |
|
|
scale = scale[:, None, :] |
|
|
else: |
|
|
scale, shift = temb.chunk(2, dim=0) |
|
|
|
|
|
x = self.norm(x) * (1 + scale) + shift |
|
|
return x |
|
|
|
|
|
|
|
|
class AMDTransformerBlock(nn.Module): |
|
|
r""" |
|
|
AMDTransformerBlock |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.norm1 = AMDLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
|
|
|
self.attn1 = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AMDLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
encoder_hidden_states: torch.Tensor, |
|
|
temb: torch.Tensor, |
|
|
    ) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
""" |
|
|
************************* ****************** |
|
|
* encoder_hidden_states * * hidden_states * |
|
|
************************* ****************** |
|
|
""" |
|
|
norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1( |
|
|
hidden_states, encoder_hidden_states, temb |
|
|
) |
|
|
|
|
|
|
|
|
image_length = norm_encoder_hidden_states.shape[1] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) |
|
|
attn_output = self.attn1( |
|
|
hidden_states=norm_hidden_states, |
|
|
encoder_hidden_states=None, |
|
|
) |
|
|
if VIS_ATTEN_FLAG: |
|
|
global attention_maps |
|
|
attn_score = simple_attention_processor(self.attn1, norm_hidden_states) |
|
|
attention_maps.append(attn_score.detach().cpu()) |
|
|
|
|
|
hidden_states = hidden_states + gate_msa * attn_output[:, image_length:] |
|
|
encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, :image_length] |
|
|
|
|
|
|
|
|
norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2( |
|
|
hidden_states, encoder_hidden_states, temb |
|
|
) |
|
|
|
|
|
|
|
|
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1) |
|
|
ff_output = self.ff(norm_hidden_states) |
|
|
|
|
|
hidden_states = hidden_states + gate_ff * ff_output[:, image_length:] |
|
|
encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :image_length] |
|
|
|
|
|
return hidden_states, encoder_hidden_states |
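
# Token layout sketch (illustrative): inside AMDTransformerBlock the two streams are
# concatenated as [encoder_hidden_states | hidden_states] before joint self-attention
# and the joint feed-forward, then split back at `image_length`, so each stream keeps
# its own residual connection and adaLN-Zero gates.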
|
|
|
|
|
class BasicDiTBlock(nn.Module): |
|
|
r""" |
|
|
AMDTransformerBlock |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.norm1 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
|
|
|
self.attn1 = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
temb: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
norm_hidden_states, gate_msa = self.norm1( |
|
|
hidden_states, temb |
|
|
) |
|
|
|
|
|
attn_output = self.attn1( |
|
|
hidden_states=norm_hidden_states, |
|
|
encoder_hidden_states=None, |
|
|
) |
|
|
|
|
|
hidden_states = hidden_states + gate_msa * attn_output |
|
|
|
|
|
|
|
|
norm_hidden_states, gate_ff = self.norm2( |
|
|
hidden_states, temb |
|
|
) |
|
|
|
|
|
|
|
|
ff_output = self.ff(norm_hidden_states) |
|
|
|
|
|
hidden_states = hidden_states + gate_ff * ff_output |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class AMDTransformerMotionBlock(nn.Module): |
|
|
r""" |
|
|
AMDTransformerBlock |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.norm1 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
|
|
|
self.attn1 = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
temb: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
|
|
|
norm_hidden_states, gate_msa = self.norm1( |
|
|
hidden_states, temb |
|
|
) |
|
|
|
|
|
attn_output = self.attn1( |
|
|
hidden_states=norm_hidden_states, |
|
|
encoder_hidden_states=None, |
|
|
) |
|
|
|
|
|
hidden_states = hidden_states + gate_msa * attn_output |
|
|
|
|
|
|
|
|
norm_hidden_states, gate_ff = self.norm2( |
|
|
hidden_states, temb |
|
|
) |
|
|
|
|
|
|
|
|
ff_output = self.ff(norm_hidden_states) |
|
|
|
|
|
hidden_states = hidden_states + gate_ff * ff_output |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class TransformerBlock2Condition(nn.Module): |
|
|
r""" |
|
|
AMDTransformerBlock |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.norm1 = AMDLayerNormZero2Condition(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
|
|
|
self.attn1 = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AMDLayerNormZero2Condition(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
condition_states1: torch.Tensor, |
|
|
condition_states2: torch.Tensor, |
|
|
temb: torch.Tensor, |
|
|
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
""" |
|
|
************************* ****************** ******************** |
|
|
* hidden_states * *condition_states1* * condition_states2* |
|
|
************************* ****************** ******************** |
|
|
""" |
|
|
hidden_length = hidden_states.shape[1] |
|
|
condition1_length = condition_states1.shape[1] |
|
|
condition2_length = condition_states2.shape[1] |
|
|
|
|
|
norm_hidden_states, norm_condition_states1,norm_condition_states2, gate_msa, c_gate_msa1,c_gate_msa2 = self.norm1( |
|
|
hidden_states, condition_states1,condition_states2, temb |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
norm_hidden_states = torch.cat([norm_hidden_states, norm_condition_states1,norm_condition_states2], dim=1) |
|
|
attn_output = self.attn1( |
|
|
hidden_states=norm_hidden_states, |
|
|
encoder_hidden_states=None, |
|
|
) |
|
|
|
|
|
hidden_states = hidden_states + gate_msa * attn_output[:,:hidden_length] |
|
|
condition_states1 = condition_states1 + c_gate_msa1 * attn_output[:, hidden_length:hidden_length+condition1_length] |
|
|
condition_states2 = condition_states2 + c_gate_msa2 * attn_output[:, hidden_length+condition1_length:] |
|
|
|
|
|
|
|
|
norm_hidden_states, norm_condition_states1,norm_condition_states2, gate_ff, c_gate_ff1,c_gate_ff2 = self.norm2( |
|
|
hidden_states, condition_states1,condition_states2, temb |
|
|
) |
|
|
|
|
|
|
|
|
norm_hidden_states = torch.cat([norm_hidden_states, norm_condition_states1,norm_condition_states2], dim=1) |
|
|
ff_output = self.ff(norm_hidden_states) |
|
|
|
|
|
hidden_states = hidden_states + gate_ff * ff_output[:, :hidden_length] |
|
|
condition_states1 = condition_states1 + c_gate_ff1 * ff_output[:, hidden_length:hidden_length+condition1_length] |
|
|
condition_states2 = condition_states2 + c_gate_ff2 * ff_output[:, hidden_length+condition1_length:] |
|
|
|
|
|
return hidden_states, condition_states1,condition_states2 |
|
|
|
|
|
class TransformerBlock2Condition_SimpleAdaLN(nn.Module): |
|
|
r""" |
|
|
TransformerBlock2Condition_SimpleAdaLN |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.norm1 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
self.norm1_condition1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) |
|
|
self.norm1_condition2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) |
|
|
|
|
|
self.attn1 = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AMDLayerNormZero_OneVariable(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True) |
|
|
self.norm2_condition1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) |
|
|
self.norm2_condition2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
condition_states1: torch.Tensor, |
|
|
condition_states2: torch.Tensor, |
|
|
temb: torch.Tensor, |
|
|
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
""" |
|
|
************************* ****************** ******************** |
|
|
* hidden_states * *condition_states1* * condition_states2* |
|
|
************************* ****************** ******************** |
|
|
""" |
|
|
hidden_length = hidden_states.shape[1] |
|
|
condition1_length = condition_states1.shape[1] |
|
|
condition2_length = condition_states2.shape[1] |
|
|
|
|
|
|
|
|
norm_hidden_states,gate = self.norm1(hidden_states, temb=temb) |
|
|
norm_condition_states1 = self.norm1_condition1(condition_states1) |
|
|
norm_condition_states2 = self.norm1_condition2(condition_states2) |
|
|
|
|
|
|
|
|
|
|
|
norm_hidden_states = torch.cat([norm_hidden_states, norm_condition_states1,norm_condition_states2], dim=1) |
|
|
attn_output = self.attn1( |
|
|
hidden_states=norm_hidden_states, |
|
|
encoder_hidden_states=None, |
|
|
) |
|
|
|
|
|
hidden_states = hidden_states + gate * attn_output[:,:hidden_length] |
|
|
condition_states1 = condition_states1 + attn_output[:, hidden_length:hidden_length+condition1_length] |
|
|
condition_states2 = condition_states2 + attn_output[:, hidden_length+condition1_length:] |
|
|
|
|
|
|
|
|
norm_hidden_states,gate = self.norm2(hidden_states, temb=temb) |
|
|
norm_condition_states1 = self.norm2_condition1(condition_states1) |
|
|
norm_condition_states2 = self.norm2_condition2(condition_states2) |
|
|
|
|
|
|
|
|
|
|
|
norm_hidden_states = torch.cat([norm_hidden_states, norm_condition_states1,norm_condition_states2], dim=1) |
|
|
ff_output = self.ff(norm_hidden_states) |
|
|
|
|
|
hidden_states = hidden_states + gate * ff_output[:, :hidden_length] |
|
|
condition_states1 = condition_states1 + ff_output[:, hidden_length:hidden_length+condition1_length] |
|
|
condition_states2 = condition_states2 + ff_output[:, hidden_length+condition1_length:] |
|
|
|
|
|
return hidden_states, condition_states1,condition_states2 |
|
|
|
|
|
|
|
|
class Any2MotionTransformerBlock(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
motion_frames : int, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
self.motion_frames = motion_frames |
|
|
|
|
|
|
|
|
self.norm1 = AdaLayerNorm( |
|
|
embedding_dim=time_embed_dim, |
|
|
output_dim=dim*2, |
|
|
norm_elementwise_affine=norm_elementwise_affine, |
|
|
norm_eps=norm_eps, |
|
|
chunk_dim=1 |
|
|
) |
|
|
|
|
|
|
|
|
self.attn1 = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AdaLayerNorm( |
|
|
embedding_dim=time_embed_dim, |
|
|
output_dim=dim*2, |
|
|
norm_elementwise_affine=norm_elementwise_affine, |
|
|
norm_eps=norm_eps, |
|
|
chunk_dim=1 |
|
|
) |
|
|
|
|
|
|
|
|
self.attn2 = Attention( |
|
|
query_dim=dim, |
|
|
cross_attention_dim=dim, |
|
|
heads=num_attention_heads, |
|
|
dim_head=attention_head_dim, |
|
|
dropout=dropout, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm3 = AdaLayerNorm( |
|
|
embedding_dim=time_embed_dim, |
|
|
output_dim=dim*2, |
|
|
norm_elementwise_affine=norm_elementwise_affine, |
|
|
norm_eps=norm_eps, |
|
|
chunk_dim=1 |
|
|
) |
|
|
|
|
|
|
|
|
self.attn3 = Attention( |
|
|
query_dim=dim, |
|
|
cross_attention_dim=dim, |
|
|
heads=num_attention_heads, |
|
|
dim_head=attention_head_dim, |
|
|
dropout=dropout, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm4 = AdaLayerNorm( |
|
|
embedding_dim=time_embed_dim, |
|
|
output_dim=dim*2, |
|
|
norm_elementwise_affine=norm_elementwise_affine, |
|
|
norm_eps=norm_eps, |
|
|
chunk_dim=1 |
|
|
) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
|
|
|
def forward( |
|
|
self, |
|
|
hidden_states: torch.Tensor, |
|
|
refimg_states: torch.Tensor, |
|
|
extra_states: torch.Tensor, |
|
|
temb: torch.Tensor, |
|
|
) -> torch.Tensor: |
|
|
|
|
|
        assert hidden_states.dim() == refimg_states.dim() == extra_states.dim(), (
            f"hidden_states.dim(): {hidden_states.dim()}, refimg_states.dim(): {refimg_states.dim()}, "
            f"extra_states.dim(): {extra_states.dim()}"
        )
|
|
|
|
|
|
|
|
hidden_states = self.norm1(hidden_states, temb=temb) |
|
|
|
|
|
|
|
|
hidden_states = einops.rearrange(hidden_states, '(b f) l d -> b (f l) d',f=self.motion_frames) |
|
|
attn_output = self.attn1(hidden_states, None) |
|
|
hidden_states = hidden_states + attn_output |
|
|
hidden_states = einops.rearrange(hidden_states, 'b (f l) d -> (b f) l d',f=self.motion_frames) |
|
|
|
|
|
|
|
|
hidden_states = self.norm2(hidden_states, temb=temb) |
|
|
|
|
|
|
|
|
attn_output = self.attn2(hidden_states, refimg_states) |
|
|
|
|
|
|
|
|
hidden_states = hidden_states + attn_output |
|
|
hidden_states = self.norm3(hidden_states, temb=temb) |
|
|
|
|
|
|
|
|
attn_output = self.attn3(hidden_states, extra_states) |
|
|
|
|
|
|
|
|
hidden_states = hidden_states + attn_output |
|
|
hidden_states = self.norm4(hidden_states, temb=temb) |
|
|
|
|
|
|
|
|
hidden_states = self.ff(hidden_states) + hidden_states |
|
|
|
|
|
return hidden_states |
|
|
|
|
|
|
|
|
class A2MCrossAttnBlock(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.norm1 = AMDLayerNormZero(conditioning_dim=time_embed_dim, |
|
|
embedding_dim=dim) |
|
|
|
|
|
|
|
|
self.attn = Attention( |
|
|
query_dim=dim, |
|
|
cross_attention_dim=dim, |
|
|
heads=num_attention_heads, |
|
|
dim_head=attention_head_dim, |
|
|
dropout=dropout, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AMDLayerNormZero(conditioning_dim=time_embed_dim, |
|
|
embedding_dim=dim) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
motion_hidden_states: torch.Tensor, |
|
|
ref_motion_hidden_states: torch.Tensor, |
|
|
        condition_hidden_states: torch.Tensor,
|
|
temb: torch.Tensor, |
|
|
    ) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
        N, FL, D = motion_hidden_states.shape
        N, TL, D = ref_motion_hidden_states.shape
        N, T_F, W, D = condition_hidden_states.shape
        # Tokens per frame once ref and motion tokens are concatenated.
        L = (FL + TL) // T_F

        # Flatten the per-frame condition windows: (n, f, w, d) -> ((n f), w, d).
        condition_hidden_states = einops.rearrange(condition_hidden_states, 'n f w d -> (n f) w d')
|
|
|
|
|
|
|
|
|
|
|
norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm1( |
|
|
motion_hidden_states, ref_motion_hidden_states, temb |
|
|
) |
|
|
|
|
|
|
|
|
hidden_states = torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) |
|
|
hidden_states = einops.rearrange(hidden_states,'n (f l) d -> (n f) l d',l=L) |
|
|
        assert hidden_states.shape[0] == condition_hidden_states.shape[0], f'hidden_states.shape {hidden_states.shape}, condition_hidden_states.shape {condition_hidden_states.shape}'
|
|
|
|
|
|
|
|
        attn_output = self.attn(hidden_states, condition_hidden_states)
|
|
attn_output = einops.rearrange(attn_output,'(n f) l d -> n f l d',n=N).flatten(1,2) |
|
|
motion_hidden_states = motion_hidden_states + gate_msa * attn_output[:,TL:] |
|
|
ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * attn_output[:,:TL] |
|
|
|
|
|
|
|
|
norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm2( |
|
|
motion_hidden_states, ref_motion_hidden_states, temb |
|
|
) |
|
|
hidden_states = torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) |
|
|
|
|
|
|
|
|
hidden_states = self.ff(hidden_states) |
|
|
motion_hidden_states = motion_hidden_states + gate_msa * hidden_states[:,TL:] |
|
|
ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * hidden_states[:,:TL] |
|
|
|
|
|
return motion_hidden_states,ref_motion_hidden_states |
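
# Grouping sketch (illustrative): A2MCrossAttnBlock rearranges the concatenated
# [ref | motion] tokens from (n, (f l), d) to ((n f), l, d) so that each frame's
# l tokens cross-attend only to that frame's condition window of shape
# ((n f), w, d), then folds the result back to (n, (f l), d) before the gated
# residual updates.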
|
|
|
|
|
|
|
|
class A2MMotionSelfAttnBlock(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.norm1 = AMDLayerNormZero(conditioning_dim=time_embed_dim, |
|
|
embedding_dim=dim) |
|
|
|
|
|
|
|
|
self.attn = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AMDLayerNormZero(conditioning_dim=time_embed_dim, |
|
|
embedding_dim=dim) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
|
|
|
def forward( |
|
|
self, |
|
|
motion_hidden_states: torch.Tensor, |
|
|
ref_motion_hidden_states: torch.Tensor, |
|
|
temb: torch.Tensor, |
|
|
    ) -> Tuple[torch.Tensor, torch.Tensor]:
|
|
|
|
|
N,FL,D = motion_hidden_states.shape |
|
|
N,TL,D = ref_motion_hidden_states.shape |
|
|
|
|
|
|
|
|
norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm1( |
|
|
motion_hidden_states, ref_motion_hidden_states, temb |
|
|
) |
|
|
hidden_states = torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) |
|
|
|
|
|
|
|
|
attn_output = self.attn(hidden_states, None) |
|
|
motion_hidden_states = motion_hidden_states + gate_msa * attn_output[:,TL:] |
|
|
ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * attn_output[:,:TL] |
|
|
|
|
|
|
|
|
norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm2( |
|
|
motion_hidden_states, ref_motion_hidden_states, temb |
|
|
) |
|
|
hidden_states = torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) |
|
|
|
|
|
|
|
|
hidden_states = self.ff(hidden_states) |
|
|
motion_hidden_states = motion_hidden_states + gate_msa * hidden_states[:,TL:] |
|
|
ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * hidden_states[:,:TL] |
|
|
|
|
|
return motion_hidden_states,ref_motion_hidden_states |
|
|
|
|
|
class A2MMotionSelfAttnBlockDoubleRef(nn.Module): |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dim: int, |
|
|
num_attention_heads: int, |
|
|
attention_head_dim: int, |
|
|
time_embed_dim: int, |
|
|
last_layer:bool = False, |
|
|
dropout: float = 0.0, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
attention_bias: bool = False, |
|
|
qk_norm: bool = True, |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
final_dropout: bool = True, |
|
|
ff_inner_dim: Optional[int] = None, |
|
|
ff_bias: bool = True, |
|
|
attention_out_bias: bool = True, |
|
|
): |
|
|
super().__init__() |
|
|
self.last_layer = last_layer |
|
|
|
|
|
|
|
|
self.norm1 = AMDLayerNormZero(conditioning_dim=time_embed_dim, |
|
|
embedding_dim=dim) |
|
|
|
|
|
|
|
|
self.attn = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
|
|
|
self.norm2 = AMDLayerNormZero(conditioning_dim=time_embed_dim, |
|
|
embedding_dim=dim) |
|
|
|
|
|
self.attn2 = Attention( |
|
|
query_dim=dim, |
|
|
dim_head=attention_head_dim, |
|
|
heads=num_attention_heads, |
|
|
qk_norm="layer_norm" if qk_norm else None, |
|
|
eps=1e-6, |
|
|
bias=attention_bias, |
|
|
out_bias=attention_out_bias, |
|
|
) |
|
|
|
|
|
self.norm3 = AMDLayerNormZero(conditioning_dim=time_embed_dim, |
|
|
embedding_dim=dim) |
|
|
|
|
|
if self.last_layer: |
|
|
self.norm4 = None |
|
|
else: |
|
|
self.norm4 = AMDLayerNormZero_OneVariable(conditioning_dim=time_embed_dim, |
|
|
embedding_dim=dim) |
|
|
|
|
|
self.ff = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
|
|
|
if not self.last_layer: |
|
|
self.ff2 = FeedForward( |
|
|
dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
final_dropout=final_dropout, |
|
|
inner_dim=ff_inner_dim, |
|
|
bias=ff_bias, |
|
|
) |
|
|
else: |
|
|
self.ff2 = None |
|
|
|
|
|
|
|
|
def forward( |
|
|
self, |
|
|
motion_hidden_states: torch.Tensor, |
|
|
ref_motion_hidden_states: torch.Tensor, |
|
|
randomref_motion_hidden_states: torch.Tensor, |
|
|
temb: torch.Tensor, |
|
|
motion_token: int = 1, |
|
|
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
|
|
|
N,FL,D = motion_hidden_states.shape |
|
|
N,TL,D = ref_motion_hidden_states.shape |
|
|
NF,SL,D = randomref_motion_hidden_states.shape |
|
|
L = motion_token |
|
|
T = TL // L |
|
|
F = FL // L |
|
|
assert F == NF // N |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm1( |
|
|
motion_hidden_states, ref_motion_hidden_states, temb |
|
|
) |
|
|
hidden_states = torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) |
|
|
|
|
|
|
|
|
attn_output = self.attn(hidden_states, None) |
|
|
motion_hidden_states = motion_hidden_states + gate_msa * attn_output[:,TL:] |
|
|
ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * attn_output[:,:TL] |
|
|
|
|
|
|
|
|
motion_hidden_states = einops.rearrange(motion_hidden_states,"n (f l) d -> (n f) l d",l=L) |
|
|
flat_temb = temb.repeat_interleave(F,dim=0) |
|
|
norm_motion_hidden_states, norm_randomref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm2( |
|
|
motion_hidden_states, randomref_motion_hidden_states, flat_temb |
|
|
) |
|
|
hidden_states = torch.cat([norm_randomref_motion_hidden_states,norm_motion_hidden_states],dim=1) |
|
|
|
|
|
|
|
|
attn_output = self.attn2(hidden_states, None) |
|
|
motion_hidden_states = motion_hidden_states + gate_msa * attn_output[:,SL:] |
|
|
randomref_motion_hidden_states = randomref_motion_hidden_states + enc_gate_msa * attn_output[:,:SL] |
|
|
|
|
|
|
|
|
|
|
|
motion_hidden_states = einops.rearrange(motion_hidden_states,"(n f) l d -> n (f l) d",n=N) |
|
|
norm_motion_hidden_states, norm_ref_motion_hidden_states, gate_msa, enc_gate_msa = self.norm3( |
|
|
motion_hidden_states, ref_motion_hidden_states, temb |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
hidden_states = torch.cat([norm_ref_motion_hidden_states,norm_motion_hidden_states],dim=1) |
|
|
hidden_states = self.ff(hidden_states) |
|
|
motion_hidden_states = motion_hidden_states + gate_msa * hidden_states[:,TL:] |
|
|
ref_motion_hidden_states = ref_motion_hidden_states + enc_gate_msa * hidden_states[:,:TL] |
|
|
|
|
|
|
|
|
if not self.last_layer: |
|
|
norm_randomref_motion_hidden_states,gate_msa_r = self.norm4(randomref_motion_hidden_states,flat_temb) |
|
|
norm_randomref_motion_hidden_states = self.ff2(norm_randomref_motion_hidden_states) |
|
|
randomref_motion_hidden_states = randomref_motion_hidden_states + gate_msa_r * norm_randomref_motion_hidden_states |
|
|
|
|
|
return motion_hidden_states,ref_motion_hidden_states,randomref_motion_hidden_states |
|
|
|
|
|
|
|
|
|
|
|
class AudioToImageShapeMlp(nn.Module):
    def __init__(self,
                 audio_dim: int = 384,
                 audio_block: int = 50,
                 outchannel: int = 256,
                 out_height: int = 4,
                 out_width: int = 4,
                 **kwargs
                 ):
        super().__init__()
        self.outchannel = outchannel
        self.out_height = out_height
        self.out_width = out_width

        outdim = outchannel * out_height * out_width
        self.mlp = Mlp(in_features=audio_dim * audio_block, hidden_features=outdim, out_features=outdim)

    def forward(self, audio_feature: torch.Tensor):
        """
        Args:
            audio_feature (torch.Tensor): (N, F, M, C) with M == audio_block and C == audio_dim.

        Returns:
            audio_feature (torch.Tensor): (N, F, outchannel, out_height, out_width)
        """
        # Flatten each frame's audio blocks, project, and fold into a small feature map.
        audio_feature = einops.rearrange(audio_feature, 'n f m d -> n f (m d)')
        audio_feature = self.mlp(audio_feature)
        audio_feature = einops.rearrange(
            audio_feature, 'n f (c h w) -> n f c h w',
            c=self.outchannel, h=self.out_height, w=self.out_width,
        )
        return audio_feature
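
# Minimal shape-check sketch for AudioToImageShapeMlp. The batch/frame sizes below are
# illustrative assumptions; the module only requires M == audio_block and C == audio_dim.
def _demo_audio_to_image_shape_mlp():
    mlp = AudioToImageShapeMlp(audio_dim=384, audio_block=50, outchannel=256, out_height=4, out_width=4)
    audio = torch.randn(2, 15, 50, 384)   # (N, F, M, C)
    out = mlp(audio)                      # (N, F, 256, 4, 4)
    assert out.shape == (2, 15, 256, 4, 4)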
|
|
|
|
|
class AudioFeatureMlp(nn.Module):
    def __init__(self,
                 audio_dim: int = 384,
                 audio_block: int = 50,
                 hidden_dim: int = 128,
                 **kwargs
                 ):
        super().__init__()
        # NOTE: `hidden_dim` is accepted for config compatibility but unused here;
        # the Mlp hidden width is tied to `outdim`.
        self.mlp = Mlp(in_features=audio_dim * audio_block, hidden_features=outdim, out_features=outdim)

    def forward(self, audio_feature: torch.Tensor):
        """
        Args:
            audio_feature (torch.Tensor): (N, F, M, C) with M == audio_block and C == audio_dim.

        Returns:
            audio_feature (torch.Tensor): (N, F, outdim)
        """
        audio_feature = einops.rearrange(audio_feature, 'n f m d -> n f (m d)')
        audio_feature = self.mlp(audio_feature)
        return audio_feature
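
# Minimal shape-check sketch for AudioFeatureMlp; batch/frame sizes are illustrative assumptions.
def _demo_audio_feature_mlp():
    mlp = AudioFeatureMlp(audio_dim=384, audio_block=50, outdim=1024)
    audio = torch.randn(2, 15, 50, 384)   # (N, F, M, C)
    out = mlp(audio)                      # (N, F, outdim)
    assert out.shape == (2, 15, 1024)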
|
|
|
|
|
class AudioFeatureWindowMlp(nn.Module):
    def __init__(self,
                 audio_dim: int = 384,
                 audio_block: int = 50,
                 intermediate_dim: int = 1024,
                 window_size: int = 12,
                 outdim: int = 768,
                 **kwargs
                 ):
        super().__init__()
        self.window_size = window_size
        self.ff1 = nn.Linear(audio_dim * audio_block, intermediate_dim)
        self.ff2 = nn.Linear(intermediate_dim, intermediate_dim)
        self.ff3 = nn.Linear(intermediate_dim, window_size * outdim)
        self.norm = nn.LayerNorm(outdim)

    def forward(self, audio_feature: torch.Tensor):
        """
        Args:
            audio_feature (torch.Tensor): (N, F, M, C) with M == audio_block and C == audio_dim.

        Returns:
            audio_feature (torch.Tensor): (N, F, W, D) with W == window_size and D == outdim.
        """
        # Project each frame's flattened audio blocks into a window of W tokens, then normalize.
        audio_feature = einops.rearrange(audio_feature, 'n f m d -> n f (m d)')
        audio_feature = torch.relu(self.ff1(audio_feature))
        audio_feature = torch.relu(self.ff2(audio_feature))
        audio_feature = torch.relu(self.ff3(audio_feature))
        audio_feature = einops.rearrange(audio_feature, "n f (w d) -> n f w d", w=self.window_size)
        audio_feature = self.norm(audio_feature)
        return audio_feature
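
# Minimal shape-check sketch for AudioFeatureWindowMlp; batch/frame sizes are illustrative.
def _demo_audio_feature_window_mlp():
    mlp = AudioFeatureWindowMlp(audio_dim=384, audio_block=50, window_size=12, outdim=768)
    audio = torch.randn(2, 15, 50, 384)   # (N, F, M, C)
    out = mlp(audio)                      # (N, F, W, D)
    assert out.shape == (2, 15, 12, 768)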
|
|
|
|
|
|
|
|
class RefMotionRefImgeBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = True,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # 1. Self-attention over the input tokens.
        self.norm1 = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # 2. Cross-attention to the reference-motion tokens.
        self.norm2 = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.attn2 = Attention(
            query_dim=dim,
            cross_attention_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # 3. Cross-attention to the reference-image tokens.
        self.norm3 = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.attn3 = Attention(
            query_dim=dim,
            cross_attention_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # 4. Feed-forward.
        self.norm4 = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        refmotion_states: torch.Tensor,
        refimg_states: torch.Tensor,
        temb: torch.Tensor,
    ) -> torch.Tensor:
        assert hidden_states.dim() == refmotion_states.dim() == refimg_states.dim(), (
            f"hidden_states.dim(): {hidden_states.dim()}, "
            f"refmotion_states.dim(): {refmotion_states.dim()}, "
            f"refimg_states.dim(): {refimg_states.dim()}"
        )

        # 1. Self-attention.
        hidden_states = self.norm1(hidden_states, temb=temb)
        attn_output = self.attn1(hidden_states, None)
        hidden_states = hidden_states + attn_output

        # 2. Cross-attention to reference motion.
        hidden_states = self.norm2(hidden_states, temb=temb)
        attn_output = self.attn2(hidden_states, refmotion_states)
        hidden_states = hidden_states + attn_output

        # 3. Cross-attention to reference image.
        hidden_states = self.norm3(hidden_states, temb=temb)
        attn_output = self.attn3(hidden_states, refimg_states)
        hidden_states = hidden_states + attn_output

        # 4. Feed-forward.
        hidden_states = self.norm4(hidden_states, temb=temb)
        hidden_states = self.ff(hidden_states) + hidden_states

        return hidden_states
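
# Hedged usage sketch for RefMotionRefImgeBlock. `dim`, the sequence lengths, and the
# (N, time_embed_dim) shape assumed for `temb` are illustrative; AdaLayerNorm is expected
# to be defined elsewhere in this file.
def _demo_ref_motion_ref_imge_block():
    block = RefMotionRefImgeBlock(dim=256, num_attention_heads=4, attention_head_dim=64, time_embed_dim=512)
    hidden = torch.randn(2, 16, 256)       # noisy latent tokens
    ref_motion = torch.randn(2, 8, 256)    # reference-motion tokens
    ref_img = torch.randn(2, 8, 256)       # reference-image tokens
    temb = torch.randn(2, 512)
    out = block(hidden, ref_motion, ref_img, temb)
    assert out.shape == hidden.shape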
|
|
|
|
|
class Audio2Pose(nn.Module):
    def __init__(self,
                 audio_dim: int = 384,
                 audio_block: int = 50,
                 motion_height: int = 4,
                 motion_width: int = 4,
                 motion_dim: int = 256,
                 pose_width: int = 32,
                 pose_height: int = 32,
                 pose_dim: int = 4,
                 num_frames: int = 15,
                 **kwargs
                 ):
        super().__init__()
        self.num_frames = num_frames
        self.pw = pose_width
        self.ph = pose_height
        self.pc = pose_dim
        self.audio_encoder = AudioToImageShapeMlp(
            audio_dim=audio_dim,
            audio_block=audio_block,
            outchannel=motion_dim,
            out_height=motion_height,
            out_width=motion_width,
        )
        self.pose_predictor = Upsampler(
            in_channel=motion_dim,
            out_channel=pose_dim,
            block_out_channels=(motion_dim, 128, 64, 32),
        )
        self.pose_downsample = DownEncoder(in_channel=pose_dim, block_out_channels=(32, 64, 128, motion_dim))

    def forward(self, audio_feature: torch.Tensor, pose_gt: torch.Tensor):
        """
        Args:
            audio_feature (torch.Tensor): (N, F, M, D)
            pose_gt (torch.Tensor): (N, F, C, H, W)

        Returns:
            pose_pred (torch.Tensor): (N, F, C, H, W), used for the pose loss
            pose_gt_transform (torch.Tensor): (N, F, motion_dim, motion_height, motion_width), used for diffusion
            audio_hidden_state (torch.Tensor): (N, F, motion_dim, motion_height, motion_width), used for audio-condition injection
        """
        b, f, m, d = audio_feature.shape
        audio_hidden_state = self.audio_encoder(audio_feature)

        # Fold frames into the batch for the per-frame conv stacks.
        audio_hidden_state = einops.rearrange(audio_hidden_state, 'n f c h w -> (n f) c h w')
        pose_pred = self.pose_predictor(audio_hidden_state)

        pose_gt = einops.rearrange(pose_gt, 'n f c h w -> (n f) c h w')
        pose_gt_transform = self.pose_downsample(pose_gt)

        pose_pred = einops.rearrange(pose_pred, '(n f) c h w -> n f c h w', n=b)
        pose_gt_transform = einops.rearrange(pose_gt_transform, '(n f) c h w -> n f c h w', n=b)
        audio_hidden_state = einops.rearrange(audio_hidden_state, '(n f) c h w -> n f c h w', n=b)

        return pose_pred, pose_gt_transform, audio_hidden_state

    def prepare_extra(self, audio: torch.Tensor, pose: torch.Tensor):
        # `pose` is kept in the signature for interface compatibility; it is not used here.
        b = audio.shape[0]
        audio_hidden_state = self.audio_encoder(audio)
        # Fold frames into the batch before the conv stacks (mirrors `forward`).
        audio_hidden_state = einops.rearrange(audio_hidden_state, 'n f c h w -> (n f) c h w')
        pose_pred = self.pose_predictor(audio_hidden_state)
        pose_pred = self.pose_downsample(pose_pred)
        pose_pred = einops.rearrange(pose_pred, '(n f) c h w -> n f c h w', n=b)
        audio_hidden_state = einops.rearrange(audio_hidden_state, '(n f) c h w -> n f c h w', n=b)
        return audio_hidden_state, pose_pred
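
# Hedged usage sketch for Audio2Pose. It relies on `Upsampler` and `DownEncoder` defined
# elsewhere in this file; the sizes below follow the constructor defaults and the docstring,
# with illustrative batch/frame counts.
def _demo_audio2pose():
    model = Audio2Pose()
    audio = torch.randn(2, 15, 50, 384)       # (N, F, M, D)
    pose_gt = torch.randn(2, 15, 4, 32, 32)   # (N, F, C, H, W)
    pose_pred, pose_gt_transform, audio_hidden_state = model(audio, pose_gt)
    # Expected per the docstring: pose_pred (2, 15, 4, 32, 32); the other two (2, 15, 256, 4, 4).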
|
|
|
|
|
class MotionTrensferBlock(nn.Module):
    r"""
    Joint block that runs gated self-attention and a gated feed-forward over the
    concatenation of `hidden_states` (motion tokens) and `encoder_hidden_states`
    (reference tokens), updating both streams.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: int,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        self.norm1 = AMDLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        self.norm2 = AMDLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Gated joint self-attention over [motion tokens, reference tokens].
        norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
            hidden_states, encoder_hidden_states, temb
        )
        motion_length = norm_hidden_states.shape[1]
        norm_hidden_states = torch.cat([norm_hidden_states, norm_encoder_hidden_states], dim=1)
        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
        )
        hidden_states = hidden_states + gate_msa * attn_output[:, :motion_length]
        encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_output[:, motion_length:]

        # 2. Gated joint feed-forward. Keep the same [motion, reference] ordering as the
        # attention step so the gating slices below pick out the matching segments.
        norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
            hidden_states, encoder_hidden_states, temb
        )
        norm_hidden_states = torch.cat([norm_hidden_states, norm_encoder_hidden_states], dim=1)
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + gate_ff * ff_output[:, :motion_length]
        encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, motion_length:]

        return hidden_states, encoder_hidden_states
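
# Hedged usage sketch for MotionTrensferBlock. AMDLayerNormZero is defined elsewhere in
# this file; the (N, time_embed_dim) shape assumed for `temb` and all sizes are illustrative.
def _demo_motion_trensfer_block():
    block = MotionTrensferBlock(dim=256, num_attention_heads=4, attention_head_dim=64, time_embed_dim=512)
    motion = torch.randn(2, 16, 256)
    reference = torch.randn(2, 8, 256)
    temb = torch.randn(2, 512)
    motion_out, reference_out = block(motion, reference, temb)
    assert motion_out.shape == motion.shape and reference_out.shape == reference.shape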
|
|
|
|
|
|
|
|
|
|
|
class A2PTemporalSpatialBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: Optional[int] = None,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # Temporal self-attention (across frames, per spatial token).
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        # Spatial self-attention (across tokens, per frame).
        self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.attn2 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        self.norm3 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        N, num_frames, L, D = hidden_states.shape  # named to avoid shadowing `torch.nn.functional as F`

        # 1. Temporal self-attention: fold spatial tokens into the batch.
        hidden_states = einops.rearrange(hidden_states, 'n f l d -> (n l) f d')
        norm_hidden_states = self.norm1(hidden_states)
        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
        )
        hidden_states = hidden_states + attn_output

        # 2. Spatial self-attention: fold frames into the batch.
        hidden_states = einops.rearrange(hidden_states, '(n l) f d -> (n f) l d', n=N, l=L)
        norm_hidden_states = self.norm2(hidden_states)
        attn_output = self.attn2(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
        )
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward.
        norm_hidden_states = self.norm3(hidden_states)
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output

        hidden_states = einops.rearrange(hidden_states, '(n f) l d -> n f l d', n=N)
        return hidden_states
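
# Minimal shape-check sketch for A2PTemporalSpatialBlock; all sizes are illustrative,
# with D == num_attention_heads * attention_head_dim.
def _demo_a2p_temporal_spatial_block():
    block = A2PTemporalSpatialBlock(dim=256, num_attention_heads=4, attention_head_dim=64)
    x = torch.randn(2, 15, 16, 256)   # (N, F, L, D)
    out = block(x)
    assert out.shape == x.shape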
|
|
|
|
|
|
|
|
class A2PCrossAudioBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        time_embed_dim: Optional[int] = None,
        dropout: float = 0.0,
        activation_fn: str = "gelu-approximate",
        attention_bias: bool = False,
        qk_norm: bool = True,
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        final_dropout: bool = True,
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
    ):
        super().__init__()

        # Per-frame cross-attention from pose/motion tokens to audio tokens.
        self.norm1 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.attn1 = Attention(
            query_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            qk_norm="layer_norm" if qk_norm else None,
            eps=1e-6,
            bias=attention_bias,
            out_bias=attention_out_bias,
        )

        self.norm2 = nn.LayerNorm(dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        N, num_frames, L, D = hidden_states.shape  # named to avoid shadowing `torch.nn.functional as F`
        _, _, W, _ = audio_hidden_states.shape

        # 1. Per-frame cross-attention to the audio window tokens.
        hidden_states = einops.rearrange(hidden_states, 'n f l d -> (n f) l d')
        norm_hidden_states = self.norm1(hidden_states)
        audio_hidden_states = einops.rearrange(audio_hidden_states, 'n f w d -> (n f) w d')
        attn_output = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=audio_hidden_states,
        )
        hidden_states = hidden_states + attn_output

        # 2. Feed-forward.
        norm_hidden_states = self.norm2(hidden_states)
        ff_output = self.ff(norm_hidden_states)
        hidden_states = hidden_states + ff_output

        hidden_states = einops.rearrange(hidden_states, '(n f) l d -> n f l d', n=N, f=num_frames)
        return hidden_states
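
# Minimal shape-check sketch for A2PCrossAudioBlock; sizes are illustrative. The audio
# stream must share N, F, and D with the pose/motion stream.
def _demo_a2p_cross_audio_block():
    block = A2PCrossAudioBlock(dim=256, num_attention_heads=4, attention_head_dim=64)
    x = torch.randn(2, 15, 16, 256)       # (N, F, L, D)
    audio = torch.randn(2, 15, 12, 256)   # (N, F, W, D)
    out = block(x, audio)
    assert out.shape == x.shape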
|
|
|