# semo/model/transformer.py
import torch
import torch.nn as nn
from typing import Optional, Union, Dict, Any
import einops
from diffusers.utils import is_torch_version
from transformers import AutoModel
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler
from diffusers.models.attention import Attention, FeedForward
from .modules import (BasicTransformerBlock,
PatchEmbed,
AMDTransformerBlock,
AdaLayerNorm,
TransformerBlock2Condition,
TransformerBlock2Condition_SimpleAdaLN,
A2MMotionSelfAttnBlock,
A2MCrossAttnBlock,
A2PTemporalSpatialBlock,
A2PCrossAudioBlock,
AMDTransformerMotionBlock,
BasicDiTBlock,
A2MMotionSelfAttnBlockDoubleRef)
from diffusers.models.embeddings import (TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed,
                                         get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid)
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.utils import BaseOutput, logging
from diffusers.models.modeling_utils import ModelMixin
# ------------ motion encoder ---------------
class MotionEncoderLearnTokenTransformer(nn.Module):
r"""
Motion Encoder With Learnable Token
"""
def __init__(
self,
# ----- img
img_height: int = 32,
img_width: int = 32,
img_inchannel: int = 4,
img_patch_size: int = 2,
# ----- motion
motion_token_num:int = 12,
motion_channel:int = 128,
need_norm_out :bool = True,
# ----- attention
num_attention_heads: int = 12,
attention_head_dim: int = 64,
num_layers: int = 8,
freq_shift: int = 0,
dropout: float = 0.0,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# setting
hidden_dim = num_attention_heads * attention_head_dim
iph = img_height // img_patch_size
ipw = img_width // img_patch_size
itl = iph * ipw
self.img_token_len = itl
# motion token
INIT_CONST = 0.02
self.motion_token = nn.Parameter(torch.randn(1, motion_token_num, motion_channel) * INIT_CONST)
self.motion_embed = nn.Linear(motion_channel,hidden_dim)
# img embedding
self.patch_embed = PatchEmbed(img_patch_size,img_inchannel,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 2D positional embeddings
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False) # (1,itl,hidden_dim)
        # transformer blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
# Output blocks
self.norm_final = nn.LayerNorm(hidden_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
self.proj_out = nn.Linear(hidden_dim, motion_channel)
self.need_norm_out = need_norm_out
if self.need_norm_out:
self.norm_out = nn.LayerNorm(motion_channel, eps=norm_eps, elementwise_affine=False)
self.gradient_checkpointing = False
def forward(
self,
img_hidden_states: torch.Tensor, # N,T,C,H,W | ref_frame + frames
mask_ratio = None,
):
N, T, C, H, W = img_hidden_states.shape
# motion token
motion_token = self.motion_embed(self.motion_token) # (1,motion_token_num,D)
motion_token = motion_token.repeat(N*T,1,1) # (NT,motion_token_num,D)
# img token
img_hidden_states = einops.rearrange(img_hidden_states, 'n t c h w -> (n t) c h w')
img_hidden_states = self.patch_embed(img_hidden_states) # [NT,hw,D]
        assert self.img_token_len == img_hidden_states.shape[1], 'number of image tokens must match img_token_len'
# img position encoding
pos_embeds = self.pos_embedding[:, :self.img_token_len] # [1,hw, D]
img_hidden_states = img_hidden_states + pos_embeds
img_hidden_states = self.embedding_dropout(img_hidden_states)
# random mask
if mask_ratio is not None:
img_hidden_states,_,_ = self.random_masking(img_hidden_states,mask_ratio)
# cat
hidden_states = torch.cat([motion_token,img_hidden_states],dim=1) # [NT,token_m + token_i ,D]
# Transformer blocks
for i, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
)
# Final block
motion_token = hidden_states[:, :motion_token.shape[1],:] # [NT,motion_token_num,D]
motion_token = self.norm_final(motion_token) # [NT,motion_token_num ,D]
motion_token = self.proj_out(motion_token) # [NT,motion_token_num ,motion_channel]
if self.need_norm_out:
motion_token = self.norm_out(motion_token)
        # reshape back to per-frame motion tokens
motion_token = einops.rearrange(motion_token, '(n t) l d -> n t l d',n=N) # [N,T,motion_token_num,motion_channel]
return motion_token # [N,T,motion_token_num,motion_channel]
def random_masking(self, x, mask_ratio):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))  # e.g. keep 25% of tokens when mask_ratio = 0.75
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample
# [0.3,0.5,0.9,0.1,0.6] -> [3,0,1,4,2] from small to large
        ids_shuffle = torch.argsort(noise, dim=1)  # ascending: smallest noise values come first and are kept
# [3,0,1,4,2] -> [1,2,4,0,3] restore the original order
ids_restore = torch.argsort(ids_shuffle, dim=1) # (N,L)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep] # (N, len_keep) e.g. [3,0,1,4,2] -> [3,0,1]
# index (N,len_keep,D) x_masked (N,len_keep,D)
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore) # (N,L) 0 is keep, 1 is remove
return x_masked, mask, ids_restore # (N,len_keep,D), (N,L), (N,L)
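# --- usage sketch (illustrative only, not part of the original file) ---
# A minimal example of how MotionEncoderLearnTokenTransformer appears to be
# called, assuming the block implementations in .modules behave as their call
# sites above suggest (PatchEmbed acts like a standard ViT patch embedding).
# The helper name `_demo_motion_encoder` is hypothetical.
def _demo_motion_encoder():
    enc = MotionEncoderLearnTokenTransformer()  # defaults: 32x32 latents, 12 motion tokens
    frames = torch.randn(2, 5, 4, 32, 32)       # (N, T, C, H, W): ref frame + 4 frames
    tokens = enc(frames)                        # (N, T, motion_token_num, motion_channel)
    assert tokens.shape == (2, 5, 12, 128)
    return tokens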
# ------------ motion transformer ---------------
class MotionTransformer(nn.Module):
"""
Motion
"""
_supports_gradient_checkpointing = True
def __init__(
self,
motion_token_num : int = 4,
motion_token_channel : int = 128,
motion_frames : int = 128,
attention_head_dim : int = 64,
num_attention_heads: int = 16,
num_layers: int = 8,
freq_shift: int = 0,
dropout: float = 0.0,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
self.out_channels = motion_token_channel
self.motion_token_length = motion_token_num * motion_frames
        # 2. Patch embedding
self.embed = nn.Linear(motion_token_channel,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
        # 3. 1D positional embeddings
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_token_length))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
        # 4. Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        # 5. Output blocks
self.proj_out = nn.Linear(hidden_dim,motion_token_channel)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor, # N,F,L,D
):
N,F,L,D = hidden_states.shape
        # 1. Patch embedding
hidden_states = self.embed(hidden_states) # N,F,L,D
        # 2. Position embedding
hidden_states = hidden_states.flatten(1,2) + self.motion_pos_embedding[:,:F*L,:] # N,FL,D
        # 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
)
hidden_states = self.norm_final(hidden_states) # N,FL,D
        # 4. Final projection
        hidden_states = self.proj_out(hidden_states) # N,FL,D
        # 5. Reshape back to per-frame tokens
hidden_states = einops.rearrange(hidden_states,'n (f l) d -> n f l d',f=F)
return hidden_states # N,F,L,D
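# --- usage sketch (illustrative only) ---
# MotionTransformer flattens the (frames, tokens) axes into one sequence and
# adds a 1D sinusoidal position embedding, so F * L must not exceed
# motion_token_num * motion_frames. The helper name below is hypothetical.
def _demo_motion_transformer():
    model = MotionTransformer()     # defaults: 4 tokens x up to 128 frames
    x = torch.randn(2, 16, 4, 128)  # (N, F, L, D): 16 frames of 4 motion tokens
    y = model(x)                    # same shape out: (N, F, L, D)
    assert y.shape == x.shape
    return y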
# ------------ diffusion ---------------
class AMDReconstructTransformerModel(nn.Module):
"""
Diffusion Transformer
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 20,
attention_head_dim: int = 64,
out_channels: Optional[int] = 4,
num_layers: int = 12,
# ----- img
image_width: int = 32,
image_height: int = 32,
image_patch_size: int = 2,
image_in_channels: Optional[int] = 4,
# ----- motion
motion_token_num:int = 12,
motion_in_channels: Optional[int] = 128,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
        Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = iph * ipw # image token length
self.image_patch_size = image_patch_size
self.out_channels = out_channels
# Patch embedding (N,F,C,H,W) -> (B,S,D)
self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 2D positional embeddings
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 1D position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# Split Token
self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
        # Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# Output blocks
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_source_hidden_states: torch.Tensor,
motion_target_hidden_states: torch.Tensor,
image_hidden_states: torch.Tensor,
**kwargs,
):
"""
motion_hidden_states : (b,d,h,w)
image_hidden_states : (b,2c,H,W)
time_step : (b,)
"""
N,Ci,Hi,Wi = image_hidden_states.shape
N,L,Cm = motion_source_hidden_states.shape
image_seq_length = Hi * Wi // (self.image_patch_size**2)
motion_seq_length = 2*L + 2
# Patch embedding
motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states) # [N,S,D]
motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states) # [N,S,D]
image_hidden_states = self.image_patch_embed(image_hidden_states)
# cat
source_token = self.source_token.repeat(N, 1, 1)
target_token = self.target_token.repeat(N, 1, 1)
motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1) # [N,2S+2,D]
# Position embedding
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# Transformer blocks
hidden_states = torch.cat([image_hidden_states,motion_hidden_states],dim=1)
for i, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
)
image_hidden_states = self.norm_final(hidden_states[:,:image_hidden_states.shape[1],:])
        # Final projection
        image_hidden_states = self.proj_out(image_hidden_states)
        # Unpatchify
p = self.image_patch_size
output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
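# --- usage sketch (illustrative only) ---
# AMDReconstructTransformerModel decodes a target latent from a reference
# latent plus source/target motion tokens, with no timestep conditioning.
# Shapes follow the constructor defaults; the helper name is hypothetical.
def _demo_reconstruct():
    model = AMDReconstructTransformerModel()    # 32x32 latents, 12 motion tokens
    motion_src = torch.randn(2, 12, 128)        # (N, L, Cm)
    motion_tgt = torch.randn(2, 12, 128)        # (N, L, Cm)
    image = torch.randn(2, 4, 32, 32)           # (N, C, H, W)
    out = model(motion_src, motion_tgt, image)  # (N, out_channels, H, W)
    assert out.shape == (2, 4, 32, 32)
    return out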
class AMDReconstructTransformerModelSpatial(nn.Module):
"""
Diffusion Transformer
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 20,
attention_head_dim: int = 64,
out_channels: Optional[int] = 4,
num_layers: int = 12,
# ----- img
image_width: int = 32,
image_height: int = 32,
image_patch_size: int = 2,
image_in_channels: Optional[int] = 4,
# ----- motion
motion_token_num:int = 12,
motion_in_channels: Optional[int] = 128,
motion_target_num_frame : int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
        Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
self.target_frame = motion_target_num_frame
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = iph * ipw # image token length
self.image_patch_size = image_patch_size
self.out_channels = out_channels
# Patch embedding (N,F,C,H,W) -> (B,S,D)
self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 2D positional embeddings
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 1D position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# 1D img temporal position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.target_frame))
img_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("img_temporal_embedding",img_pos_embedding,persistent=False)
# Split Token
self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
        # Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.spatial_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# Output blocks
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_source_hidden_states: torch.Tensor,
motion_target_hidden_states: torch.Tensor,
image_hidden_states: torch.Tensor,
**kwargs,
):
"""
motion_hidden_states : (b,d,h,w)
image_hidden_states : (b,2c,H,W)
time_step : (b,)
"""
N,Ci,Hi,Wi = image_hidden_states.shape
N,L,Cm = motion_source_hidden_states.shape
image_seq_length = Hi * Wi // (self.image_patch_size**2)
motion_seq_length = 2*L + 2
n = N // self.target_frame
t = self.target_frame
l = L
d = Cm
# Patch embedding
motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states) # [N,S,D]
motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states) # [N,S,D]
image_hidden_states = self.image_patch_embed(image_hidden_states) #
# cat
source_token = self.source_token.repeat(N, 1, 1)
target_token = self.target_token.repeat(N, 1, 1)
motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1) # [N,2S+2,D]
# Position embedding
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length] # nt,2l+2,d
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length] # nt,s,d
image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
image_hidden_states = image_hidden_states + self.img_temporal_embedding[:,:t] # ns,t,d
image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n) # nt,s,d
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# Transformer blocks
img_length = image_hidden_states.shape[1]
motion_length = motion_hidden_states.shape[1]
for block,s_block in zip(self.transformer_blocks,self.spatial_blocks):
hidden_states = torch.cat([image_hidden_states,motion_hidden_states],dim=1)
hidden_states = block(
hidden_states=hidden_states,
) # nt l+s d
image_hidden_states = hidden_states[:,:img_length,:] # nt l d
motion_hidden_states = hidden_states[:,img_length:,:]# nt s d
image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
image_hidden_states = s_block(
hidden_states = image_hidden_states,
) # nl t d
image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n) # nt l d
        # Final projection
        image_hidden_states = self.norm_final(image_hidden_states)
        image_hidden_states = self.proj_out(image_hidden_states)
        # Unpatchify
p = self.image_patch_size
output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
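# Note (added for clarity): in the Spatial variant above, a video clip is
# flattened into the batch axis, so N must equal n * motion_target_num_frame.
# Each layer first runs joint image/motion attention per frame
# (transformer_blocks), then attends across the t frames at every spatial
# location (spatial_blocks) via the (n t) s d <-> (n s) t d rearranges.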
class AMDReconstructSplitTransformerModel(nn.Module):
"""
Diffusion Transformer
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 20,
attention_head_dim: int = 64,
out_channels: Optional[int] = 4,
num_layers: int = 12,
# ----- img
image_width: int = 32,
image_height: int = 32,
image_patch_size: int = 2,
image_in_channels: Optional[int] = 4,
# ----- motion
motion_token_num:int = 12,
motion_in_channels: Optional[int] = 128,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
        Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = iph * ipw # image token length
self.image_patch_size = image_patch_size
self.out_channels = out_channels
# Patch embedding (N,F,C,H,W) -> (B,S,D)
self.zi_image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True)
self.zt_image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 2D positional embeddings
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 1D position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# Split Token
self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
        # Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# Output blocks
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_source_hidden_states: torch.Tensor,
motion_target_hidden_states: torch.Tensor,
image_hidden_states: torch.Tensor,
**kwargs,
):
"""
motion_hidden_states : (b,d,h,w)
image_hidden_states : (b,2c,H,W)
time_step : (b,)
"""
N,Ci,Hi,Wi = image_hidden_states.shape
N,L,Cm = motion_source_hidden_states.shape
image_seq_length = Hi * Wi // (self.image_patch_size**2)
motion_seq_length = 2*L + 2
# Patch embedding
motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states) # [N,S,D]
motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states) # [N,S,D]
zi_image_hidden_states = self.zi_image_patch_embed(image_hidden_states[:,:Ci//2,:,:])
zt_image_hidden_states = self.zt_image_patch_embed(image_hidden_states[:,Ci//2:,:,:])
# cat
source_token = self.source_token.repeat(N, 1, 1)
target_token = self.target_token.repeat(N, 1, 1)
motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1) # [N,2S+2,D]
# Position embedding
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
zi_image_hidden_states = zi_image_hidden_states + self.pos_embedding[:, :image_seq_length]
zt_image_hidden_states = zt_image_hidden_states + self.pos_embedding[:, :image_seq_length]
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# Transformer blocks
hidden_states = torch.cat([zt_image_hidden_states,zi_image_hidden_states,motion_hidden_states],dim=1)
for i, block in enumerate(self.transformer_blocks):
hidden_states = block(
hidden_states=hidden_states,
)
image_hidden_states = self.norm_final(hidden_states[:,:zt_image_hidden_states.shape[1],:])
        # Final projection
        image_hidden_states = self.proj_out(image_hidden_states)
        # Unpatchify
p = self.image_patch_size
output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
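# Note (added for clarity): the Split variant above expects z_i (reference)
# and z_t (target) stacked along channels, so image_in_channels counts both
# halves; each half gets its own PatchEmbed, and only the z_t token positions
# are decoded by proj_out.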
class AMDDiffusionTransformerModel(nn.Module):
"""
Diffusion Transformer
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 20,
attention_head_dim: int = 64,
out_channels: Optional[int] = 4,
num_layers: int = 12,
# ----- img
image_width: int = 32,
image_height: int = 32,
image_patch_size: int = 2,
image_in_channels: Optional[int] = 4,
# ----- motion
motion_token_num:int = 12,
motion_in_channels: Optional[int] = 128,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
        Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = iph * ipw # image token length
self.image_patch_size = image_patch_size
self.out_channels = out_channels
# Patch embedding (N,F,C,H,W) -> (B,S,D)
self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 2D positional embeddings
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 1D position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# Split Token
self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
# Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
        # Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
AMDTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * hidden_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_source_hidden_states: torch.Tensor,
motion_target_hidden_states: torch.Tensor,
image_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (b,d,h,w)
image_hidden_states : (b,2c,H,W)
time_step : (b,)
"""
N,Ci,Hi,Wi = image_hidden_states.shape
N,L,Cm = motion_source_hidden_states.shape
image_seq_length = Hi * Wi // (self.image_patch_size**2)
motion_seq_length = 2*L + 2
# Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# Patch embedding
motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states) # [N,S,D]
motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states) # [N,S,D]
image_hidden_states = self.image_patch_embed(image_hidden_states)
# cat
source_token = self.source_token.repeat(N, 1, 1)
target_token = self.target_token.repeat(N, 1, 1)
motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1) # [N,2S+2,D]
# Position embedding
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# Transformer blocks
for i, block in enumerate(self.transformer_blocks):
motion_hidden_states, image_hidden_states = block(
hidden_states=motion_hidden_states,
encoder_hidden_states=image_hidden_states,
temb=emb,
)
image_hidden_states = self.norm_final(image_hidden_states)
        # Final block
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)
        # Unpatchify
p = self.image_patch_size
output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
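# --- usage sketch (illustrative only) ---
# A single denoising forward pass, assuming diffusion-style training where
# image_hidden_states carries the noised target latent and timestep is a 1D
# batch of integer timesteps. The helper name is hypothetical.
def _demo_diffusion_step():
    model = AMDDiffusionTransformerModel()      # 32x32 latents, 12 motion tokens
    motion_src = torch.randn(2, 12, 128)        # (N, L, Cm)
    motion_tgt = torch.randn(2, 12, 128)        # (N, L, Cm)
    image = torch.randn(2, 4, 32, 32)           # (N, C, H, W) noised latent
    timestep = torch.randint(0, 1000, (2,))     # (N,)
    pred = model(motion_src, motion_tgt, image, timestep)  # (N, out_channels, H, W)
    assert pred.shape == (2, 4, 32, 32)
    return pred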
class AMDDiffusionTransformerModelDualStream(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 20,
attention_head_dim: int = 64,
out_channels: Optional[int] = 4,
num_layers: int = 12,
# ----- img
image_width: int = 32,
image_height: int = 32,
image_patch_size: int = 2,
image_in_channels: Optional[int] = 4,
# ----- motion
motion_token_num:int = 12,
motion_in_channels: Optional[int] = 128,
motion_target_num_frame : int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
        Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = iph * ipw # image token length
self.image_patch_size = image_patch_size
self.out_channels = out_channels
self.target_frame = motion_target_num_frame
# Patch embedding (N,F,C,H,W) -> (B,S,D)
self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# learnable token
INIT_CONST = 0.02
self.source_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * INIT_CONST)
self.target_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * INIT_CONST)
# 2D positional embeddings
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 1D position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2*motion_token_num+2))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2*self.target_frame * (motion_token_num+1)))
motion_temporal_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_temporal_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_temporal_embedding",motion_temporal_embedding,persistent=False)
# Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
        # Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
AMDTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.motion_blocks = nn.ModuleList(
[
AMDTransformerMotionBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * hidden_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_source_hidden_states: torch.Tensor,
motion_target_hidden_states: torch.Tensor,
image_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (nt,l,d)
image_hidden_states : (b,2c,H,W)
time_step : (b,)
"""
N,Ci,Hi,Wi = image_hidden_states.shape
N,L,Cm = motion_target_hidden_states.shape
image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2*L + 2  # two split tokens + source/target motion tokens
n = N // self.target_frame
t = self.target_frame
l = L
d = Cm
# Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb_m = einops.rearrange(emb,'(n t) d -> n t d',n=n,t=t)# n,t,d
emb_m = emb_m[:,0,:] # n,d
# Patch embedding
image_hidden_states = self.image_patch_embed(image_hidden_states) # nt,s,d
motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states) # nt,l,d
motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states) # nt,l,d
source_token = self.source_token.repeat(n*t,1,1) # nt,l,d
target_token = self.target_token.repeat(n*t,1,1) # nt,l,d
motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1) # nt,2+2l,d
# motion position encoding1
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :2*l+2] # nt,2+2l,d
# motion position encoding2
motion_hidden_states = einops.rearrange(motion_hidden_states,'(n t) l d -> n (t l) d',n=n) # n,t(2l+2),d
motion_hidden_states = motion_hidden_states + self.motion_temporal_embedding[:,:t*(2*l+2)] # n,t(2l+2),d
# img Position embedding
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
# Transformer blocks
for block,m_block in zip(self.transformer_blocks,self.motion_blocks):
# motion temporal block
motion_hidden_states = m_block(
hidden_states = motion_hidden_states,
temb = emb_m,
) # n,t(2l+2),d
# transform for block
motion_hidden_states = einops.rearrange(motion_hidden_states,'n (t l) d -> (n t) l d',t=t)
# img block
motion_hidden_states, image_hidden_states = block(
hidden_states=motion_hidden_states,
encoder_hidden_states=image_hidden_states,
temb=emb,
)
motion_hidden_states = einops.rearrange(motion_hidden_states,'(n t) l d -> n (t l) d',t=t)
image_hidden_states = self.norm_final(image_hidden_states)
        # Final block
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)
        # Unpatchify
p = self.image_patch_size
output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
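# Note (added for clarity): the DualStream variant interleaves two streams per
# layer: motion_blocks attend over the whole clip's motion sequence
# (n, t*(2l+2), d), conditioned on the first frame's time embedding, then
# transformer_blocks run joint per-frame motion/image attention. As in the
# other clip-level variants, N must equal n * motion_target_num_frame.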
class AMDDiffusionTransformerModelImgSpatial(nn.Module):
"""
Diffusion Transformer
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 20,
attention_head_dim: int = 64,
out_channels: Optional[int] = 4,
num_layers: int = 12,
# ----- img
image_width: int = 32,
image_height: int = 32,
image_patch_size: int = 2,
image_in_channels: Optional[int] = 4,
# ----- motion
motion_token_num:int = 12,
motion_in_channels: Optional[int] = 128,
motion_target_num_frame : int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
        Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = iph * ipw # image token length
self.image_patch_size = image_patch_size
self.out_channels = out_channels
self.target_frame = motion_target_num_frame
# Patch embedding (N,F,C,H,W) -> (B,S,D)
self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 2D positional embeddings
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 1D position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# 1D img temporal position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.target_frame))
img_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("img_temporal_embedding",img_pos_embedding,persistent=False)
# Split Token
self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
# Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
        # Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
AMDTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.spatial_blocks = nn.ModuleList(
[
BasicDiTBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * hidden_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_source_hidden_states: torch.Tensor,
motion_target_hidden_states: torch.Tensor,
image_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (b,d,h,w)
image_hidden_states : (b,2c,H,W)
time_step : (b,)
"""
N,Ci,Hi,Wi = image_hidden_states.shape
N,L,Cm = motion_source_hidden_states.shape
image_seq_length = Hi * Wi // (self.image_patch_size**2)
motion_seq_length = 2*L + 2
n = N // self.target_frame
t = self.target_frame
l = L
d = Cm
# Patch embedding
motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states) # [B,S,D]
motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states) # [B,S,D]
image_hidden_states = self.image_patch_embed(image_hidden_states)
# Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb_s = einops.rearrange(emb,'(n t) d -> n t d',n=n,t=t)# n,t,d
emb_s = emb_s[:,:1,:].repeat(1,image_hidden_states.shape[1],1) # n,s,d
emb_s = emb_s.flatten(0,1) # ns,d
# cat
source_token = self.source_token.repeat(N, 1, 1)
target_token = self.target_token.repeat(N, 1, 1)
motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1) # [B,2S+2,D]
# Position embedding
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length] # nt,2l+2,d
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length] # nt,s,d
image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
image_hidden_states = image_hidden_states + self.img_temporal_embedding[:,:t] # ns,t,d
image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n) # nt,s,d
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# Transformer blocks
for block,s_block in zip(self.transformer_blocks,self.spatial_blocks):
motion_hidden_states, image_hidden_states = block(
hidden_states=motion_hidden_states,
encoder_hidden_states=image_hidden_states,
temb=emb,
)
image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
image_hidden_states = s_block(
hidden_states = image_hidden_states,
temb = emb_s,
)
image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n)
image_hidden_states = self.norm_final(image_hidden_states)
        # Final block
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)
        # Unpatchify
p = self.image_patch_size
output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
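# Note (added for clarity): ImgSpatial adds per-location temporal DiT blocks
# (spatial_blocks) after each joint motion/image block; emb_s broadcasts each
# clip's first-frame time embedding across the spatial positions so that every
# (n s) t d sequence is conditioned consistently.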
class AMDDiffusionTransformerModelImgSpatialDoubleRef(nn.Module):
"""
Diffusion Transformer
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 20,
attention_head_dim: int = 64,
out_channels: Optional[int] = 4,
num_layers: int = 12,
# ----- img
image_width: int = 32,
image_height: int = 32,
image_patch_size: int = 2,
image_in_channels: Optional[int] = 4,
# ----- motion
motion_token_num:int = 12,
motion_in_channels: Optional[int] = 128,
motion_target_num_frame : int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
        Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = iph * ipw # image token length
self.image_patch_size = image_patch_size
self.out_channels = out_channels
self.target_frame = motion_target_num_frame
# Patch embedding (N,F,C,H,W) -> (B,S,D)
self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels*2,embed_dim=hidden_dim, bias=True)
self.lastref_image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 2D positional embeddings
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 1D position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# 1D img temporal position encoding
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.target_frame + 1))
img_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("img_temporal_embedding",img_pos_embedding,persistent=False)
# Split Token
self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
# Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
        # Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
AMDTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.spatial_blocks = nn.ModuleList(
[
BasicDiTBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * hidden_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_source_hidden_states: torch.Tensor,
motion_target_hidden_states: torch.Tensor,
image_hidden_states: torch.Tensor,
randomref_image_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (b,d,h,w)
image_hidden_states : (b,2c,H,W)
randomref_image_hidden_states : (b,c,H,W)
time_step : (b,)
"""
N,Ci,Hi,Wi = image_hidden_states.shape
N,L,Cm = motion_source_hidden_states.shape
image_seq_length = Hi * Wi // (self.image_patch_size**2)
motion_seq_length = 2*L + 2
n = N // self.target_frame
t = self.target_frame
l = L
d = Cm
# img
        lastref_image_hidden_states = image_hidden_states[:,:Ci//2,:,:] # (N,c,H,W)
        noised_image_hidden_states = image_hidden_states[:,Ci//2:,:,:] # (N,c,H,W)
# Patch embedding
motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states) # [B,L,D]
motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states) # [B,L,D]
image_hidden_states = self.image_patch_embed(torch.cat([randomref_image_hidden_states,noised_image_hidden_states],dim=1)) # (b,s,d)
lastref_image_hidden_states = self.lastref_image_patch_embed(lastref_image_hidden_states) # (b,s,d)
# Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
emb_s = einops.rearrange(emb,'(n t) d -> n t d',n=n,t=t)# n,t,d
emb_s = emb_s[:,:1,:].repeat(1,image_hidden_states.shape[1],1) # n,s,d
emb_s = emb_s.flatten(0,1) # ns,d
# cat
source_token = self.source_token.repeat(N, 1, 1)
target_token = self.target_token.repeat(N, 1, 1)
motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1) # [B,2S+2,D]
# Position embedding
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length] # nt,2l+2,d
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length] # nt,s,d
image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
image_hidden_states = image_hidden_states + self.img_temporal_embedding[:,1:t+1] # ns,t,d
image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n) # nt,s,d
lastref_image_hidden_states = lastref_image_hidden_states + self.pos_embedding[:, :image_seq_length] # nt,s,d
lastref_image_hidden_states = einops.rearrange(lastref_image_hidden_states,'(n t) s d -> (n s) t d',n=n) # ns,t,d
lastref_image_hidden_states = lastref_image_hidden_states[:,:1,:]# ns,1,d
lastref_image_hidden_states = lastref_image_hidden_states + self.img_temporal_embedding[:,:1] # ns,1,d
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# Transformer blocks
for block,s_block in zip(self.transformer_blocks,self.spatial_blocks):
motion_hidden_states, image_hidden_states = block(
hidden_states=motion_hidden_states,
encoder_hidden_states=image_hidden_states,
temb=emb,
)
image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n) # ns t d
s_block_input_hidden_states = torch.cat([lastref_image_hidden_states,image_hidden_states],dim=1) # ns 1+t d
s_block_output = s_block(
hidden_states = s_block_input_hidden_states,
temb = emb_s,
)
lastref_image_hidden_states = s_block_output[:,:1,:]
image_hidden_states = s_block_output[:,1:,:]
image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n)
image_hidden_states = self.norm_final(image_hidden_states)
        # Final block
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)
        # Unpatchify
p = self.image_patch_size
output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
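# --- usage sketch (illustrative only) ---
# The DoubleRef variant consumes the last-reference latent and the noised
# latent stacked along channels, plus a separate random-reference latent.
# Shapes are a sketch assuming defaults except motion_target_num_frame=2;
# the helper name is hypothetical.
def _demo_double_ref():
    model = AMDDiffusionTransformerModelImgSpatialDoubleRef(motion_target_num_frame=2)
    motion_src = torch.randn(2, 12, 128)        # (N, L, Cm), N = n * t = 1 * 2
    motion_tgt = torch.randn(2, 12, 128)
    image = torch.randn(2, 8, 32, 32)           # (N, 2c, H, W): lastref + noised
    randomref = torch.randn(2, 4, 32, 32)       # (N, c, H, W)
    timestep = torch.randint(0, 1000, (2,))     # (N,)
    pred = model(motion_src, motion_tgt, image, randomref, timestep)
    assert pred.shape == (2, 4, 32, 32)
    return pred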
class AMDDiffusionTransformerModelSplitInput(nn.Module):
"""
Diffusion Transformer
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
image_in_channels: Optional[int] = 4,
motion_in_channels: Optional[int] = 16,
out_channels: Optional[int] = 4,
num_layers: int = 16,
image_width: int = 64,
image_height: int = 64,
motion_width: int = 8,
motion_height: int = 8,
image_patch_size: int = 2,
motion_patch_size: int = 1,
motion_frames: int = 15,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
        Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = 2*iph * ipw # image token length
mph = motion_height // motion_patch_size
mpw = motion_width // motion_patch_size
mtl = mph * mpw * motion_frames # motion token num
self.max_seq_length = itl + mtl
self.image_patch_size = image_patch_size
self.motion_patch_size = motion_patch_size
self.out_channels = out_channels
# 2. Patch embedding (N,F,C,H,W) -> (B,S,D)
self.zi_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True)
self.zt_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = PatchEmbed(patch_size=motion_patch_size,in_channels=motion_in_channels,embed_dim=hidden_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 3. 3D positional embeddings
spatial_pos_embedding = get_3d_sincos_pos_embed(
hidden_dim,
(iph, ipw),
2,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) # [T*H*W, D]
pos_embedding = torch.zeros(1,*spatial_pos_embedding.shape, requires_grad=False)
pos_embedding.data.copy_(spatial_pos_embedding)
self.register_buffer("pos_embedding", pos_embedding, persistent=False)
# 4. Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
# 5. Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
AMDTransformerBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# 6. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * hidden_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_hidden_states: torch.Tensor,
image_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (b,d,h,w)
image_hidden_states : (b,2c,H,W)
timestep : (b,)
"""
N,Ci,Hi,Wi = image_hidden_states.shape
N,Cm,Hm,Wm = motion_hidden_states.shape
image_seq_length = 2 * Hi * Wi // (self.image_patch_size**2)
motion_seq_length = Hm * Wm // (self.motion_patch_size**2)
zi = image_hidden_states[:,:Ci//2]
zt = image_hidden_states[:,Ci//2:]
# 1. Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) # [N,S,D]
zi_hidden_states = self.zi_patch_embed(zi)
zt_hidden_states = self.zt_patch_embed(zt)
image_hidden_states = torch.cat((zi_hidden_states,zt_hidden_states),dim=1)
assert image_seq_length == image_hidden_states.shape[1] , f"image_seq_length : {image_seq_length} != image_hidden_states.shape[1] : {image_hidden_states.shape[1]}"
# 3. Position embedding
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
motion_hidden_states, image_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
motion_hidden_states,
image_hidden_states,
emb,
**ckpt_kwargs,
)
else:
motion_hidden_states, image_hidden_states = block(
hidden_states=motion_hidden_states,
encoder_hidden_states=image_hidden_states,
temb=emb,
)
pre = image_hidden_states[:,image_seq_length//2:] # (N,S,D)
pre = self.norm_final(pre)
# 6. Final block
pre = self.norm_out(pre, temb=emb)
pre = self.proj_out(pre)
# 7. Unpatchify
p = self.image_patch_size
output = pre.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
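# Usage sketch (assumes the default config above; shapes follow the forward
# docstring, values are illustrative, not a tested setup):
#   model = AMDDiffusionTransformerModelSplitInput()
#   image = torch.randn(2, 4, 64, 64) # (b,2c,H,W): zi and zt stacked on channels
#   motion = torch.randn(2, 16, 8, 8) # (b,d,h,w)
#   t = torch.randint(0, 1000, (2,)) # 1d timestep array
#   out = model(motion_hidden_states=motion, image_hidden_states=image, timestep=t)
#   # out: (2, 4, 64, 64) == (b, out_channels, H, W), prediction for the zt half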
class DiffusionTransformerModel2Condition(nn.Module):
"""
Diffusion Transformer
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 64,
image_in_channels: Optional[int] = 4,
motion_in_channels: Optional[int] = 16,
out_channels: Optional[int] = 4,
num_layers: int = 16,
image_width: int = 32,
image_height: int = 32,
motion_width: int = 8,
motion_height: int = 8,
image_patch_size: int = 2,
motion_patch_size: int = 1,
motion_frames: int = 15,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
"""
Training:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
Inference:
Z(N,1,C,H,W)
Motion(N,k,d,h,w)
"""
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
iph = image_height // image_patch_size
ipw = image_width // image_patch_size
itl = iph * ipw # image token length
mph = motion_height // motion_patch_size
mpw = motion_width // motion_patch_size
mtl = mph * mpw * motion_frames # motion token num
self.max_seq_length = 2 * itl + mtl
self.image_patch_size = image_patch_size
self.motion_patch_size = motion_patch_size
self.out_channels = out_channels
# 2. Patch embedding (N,F,C,H,W) -> (B,S,D)
self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
self.refimg_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
self.motion_patch_embed = PatchEmbed(patch_size=motion_patch_size,in_channels=motion_in_channels,embed_dim=hidden_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 3. 3D positional embeddings for img
spatial_pos_embedding = get_3d_sincos_pos_embed(
hidden_dim,
(iph, ipw),
2,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) # [T*H*W, D]
pos_embedding = torch.zeros(1,*spatial_pos_embedding.shape, requires_grad=False)
pos_embedding.data.copy_(spatial_pos_embedding)
self.register_buffer("img_pos_embedding", pos_embedding, persistent=False)
# 4. 3D positional embeddings for motion
spatial_pos_embedding = get_3d_sincos_pos_embed(
hidden_dim,
(mph, mpw),
motion_frames,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) # [T*H*W, D]
pos_embedding = torch.zeros(1,*spatial_pos_embedding.shape, requires_grad=False)
pos_embedding.data.copy_(spatial_pos_embedding)
self.register_buffer("motion_pos_embedding", pos_embedding, persistent=False)
# 5. Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
# 6. Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock2Condition(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# 7. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=2 * hidden_dim,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
refimg_hidden_states: torch.Tensor, # condition1
motion_hidden_states: torch.Tensor, # condition2
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
hidden_states : (b,c,H,W)
motion_hidden_states : (b,d,h,w)
refimg_hidden_states : (b,c,H,W)
timestep : (b,)
"""
N,Ci,Hi,Wi = hidden_states.shape
N,Cm,Hm,Wm = motion_hidden_states.shape
image_seq_length = Hi * Wi // (self.image_patch_size**2)
motion_seq_length = Hm * Wm // (self.motion_patch_size**2)
# 1. Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond)
# 2. Patch embedding
hidden_states = self.image_patch_embed(hidden_states) # [N,T,D]
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) # [N,S,D] condition
refimg_hidden_states = self.refimg_patch_embed(refimg_hidden_states) # [N,S,D] condition
# 3. Position embedding
hidden_states = hidden_states + self.img_pos_embedding[:,:image_seq_length,:]
refimg_hidden_states = refimg_hidden_states + self.img_pos_embedding[:, image_seq_length:,:]
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length,:]
motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, refimg_hidden_states, motion_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
refimg_hidden_states,
motion_hidden_states,
emb,
**ckpt_kwargs,
)
else:
hidden_states, refimg_hidden_states,motion_hidden_states = block(
hidden_states,
refimg_hidden_states,
motion_hidden_states,
emb,
)
hidden_states = self.norm_final(hidden_states)
# 6. Final block
hidden_states = self.norm_out(hidden_states, temb=emb)
hidden_states = self.proj_out(hidden_states)
# 7. Unpatchify
p = self.image_patch_size
output = hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) # [N,C,H,W]
return output
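# Usage sketch (default config; shapes follow the forward docstring, values
# are illustrative):
#   model = DiffusionTransformerModel2Condition()
#   x = torch.randn(2, 4, 32, 32) # noisy latent (b,c,H,W)
#   ref = torch.randn(2, 4, 32, 32) # reference image latent, condition 1
#   motion = torch.randn(2, 16, 8, 8) # motion latent, condition 2
#   t = torch.randint(0, 1000, (2,))
#   out = model(x, ref, motion, t) # -> (2, 4, 32, 32)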
# ------------ A2M ---------------
# Audio + motion_ref + motion_t
class AudioMitionref_LearnableToken(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
motion_num_token: int = 12,
motion_inchannel: int = 128,
motion_frames: int = 128,
extra_in_channels: Optional[int] = 768,
out_channels: Optional[int] = 128,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
num_layers: int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
self.out_channels = out_channels
self.motion_frames = motion_frames
self.motion_num_token = motion_num_token
self.motion_num_tokens = motion_num_token * motion_frames
# 2. Patch embedding (N,F,C,H,W) -> (B,S,D)
self.refmotion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.motion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.extra_embed = nn.Linear(extra_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 3. 1d positional embeddings
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_num_token + self.motion_num_tokens))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_frames))
audio_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
audio_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("audio_pos_embedding",audio_pos_embedding,persistent=False)
# 4. Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
# 5. Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock2Condition(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# 6. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=hidden_dim*2,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_hidden_states: torch.Tensor,
refmotion_hidden_states: torch.Tensor,
extra_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
motion_hidden_states : (N,F,L,D)
refmotion_hidden_states : (N,L,D)
extra_hidden_states : (N,F,D); apply positional encoding to this input beforehand if needed. The two inputs above always receive positional encoding here.
"""
assert motion_hidden_states.shape[1] == extra_hidden_states.shape[1], f"motion {motion_hidden_states.shape}, audio {extra_hidden_states.shape}"
N,L,D = refmotion_hidden_states.shape
N,F,L,D = motion_hidden_states.shape
# 1. Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond) # (batch,time_embed_dim) (8,512)
# 2. Patch embedding
motion_hidden_states = einops.rearrange(motion_hidden_states,'n f l d -> n (f l) d')
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) # [N,FL,D]
ref_motion_hidden_states = self.refmotion_patch_embed(refmotion_hidden_states) # [N,L,D]
extra_hidden_states = self.extra_embed(extra_hidden_states) # [N,F,D]
# 3. Position embedding
ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:,:L,:]
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:,L:,:]
extra_hidden_states = extra_hidden_states + self.audio_pos_embedding
motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
motion_hidden_states,ref_motion_hidden_states,extra_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
motion_hidden_states,
ref_motion_hidden_states,
extra_hidden_states,
emb,
**ckpt_kwargs,
)
else:
motion_hidden_states,ref_motion_hidden_states,extra_hidden_states = block(
motion_hidden_states,
ref_motion_hidden_states,
extra_hidden_states,
temb=emb,
)
motion_hidden_states = self.norm_final(motion_hidden_states)
# 6. Final block
motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
motion_hidden_states = self.proj_out(motion_hidden_states) # [N, F*L, out_channels]
# 7. Reshape back to per-frame tokens
output = einops.rearrange(motion_hidden_states,'n (f l) d -> n f l d',f=F)
return output # N,F,L,D
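# Usage sketch (default config; F must equal motion_frames and L must equal
# motion_num_token so the fixed positional buffers line up):
#   model = AudioMitionref_LearnableToken()
#   motion = torch.randn(2, 128, 12, 128) # (N,F,L,motion_inchannel)
#   ref = torch.randn(2, 12, 128) # (N,L,motion_inchannel)
#   audio = torch.randn(2, 128, 768) # (N,F,extra_in_channels)
#   t = torch.randint(0, 1000, (2,))
#   out = model(motion, ref, audio, t) # -> (2, 128, 12, 128) == (N,F,L,out_channels)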
# Audio + motion_ref + motion_t
class AudioMitionref_LearnableToken_SimpleAdaLN(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
motion_num_token: int = 12,
motion_inchannel: int = 128,
motion_frames: int = 128,
extra_in_channels: Optional[int] = 768,
out_channels: Optional[int] = 128,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
num_layers: int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
self.out_channels = out_channels
self.motion_frames = motion_frames
self.motion_num_token = motion_num_token
self.motion_num_tokens = motion_num_token * motion_frames
# 2. Patch embedding (N,F,C,H,W) -> (B,S,D)
self.refmotion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.motion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.extra_embed = nn.Linear(extra_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 3. 1d positional embeddings
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_num_token + self.motion_num_tokens))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_frames))
audio_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
audio_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("audio_pos_embedding",audio_pos_embedding,persistent=False)
# 4. Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
# 5. Define spatio-temporal transformer blocks
self.transformer_blocks = nn.ModuleList(
[
TransformerBlock2Condition_SimpleAdaLN(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# 6. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=hidden_dim*2,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_hidden_states: torch.Tensor,
refmotion_hidden_states: torch.Tensor,
extra_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
motion_hidden_states : (N,F,L,D)
refmotion_hidden_states : (N,L,D)
extra_hidden_states : (N,F,D); apply positional encoding to this input beforehand if needed. The two inputs above always receive positional encoding here.
"""
assert motion_hidden_states.shape[1] == extra_hidden_states.shape[1], f"motion {motion_hidden_states.shape}, audio {extra_hidden_states.shape}"
N,L,D = refmotion_hidden_states.shape
N,F,L,D = motion_hidden_states.shape
# 1. Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond) # (batch,time_embed_dim) (8,512)
# 2. Patch embedding
motion_hidden_states = einops.rearrange(motion_hidden_states,'n f l d -> n (f l) d')
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) # [N,FL,D]
ref_motion_hidden_states = self.refmotion_patch_embed(refmotion_hidden_states) # [N,L,D]
extra_hidden_states = self.extra_embed(extra_hidden_states) # [N,F,D]
# 3. Position embedding
ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:,:L,:]
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:,L:,:]
extra_hidden_states = extra_hidden_states + self.audio_pos_embedding
motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# 4. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
motion_hidden_states,ref_motion_hidden_states,extra_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
motion_hidden_states,
ref_motion_hidden_states,
extra_hidden_states,
emb,
**ckpt_kwargs,
)
else:
motion_hidden_states,ref_motion_hidden_states,extra_hidden_states = block(
motion_hidden_states,
ref_motion_hidden_states,
extra_hidden_states,
temb=emb,
)
motion_hidden_states = self.norm_final(motion_hidden_states)
# 6. Final block
motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
motion_hidden_states = self.proj_out(motion_hidden_states) # [N, F*L, out_channels]
# 7. Reshape back to per-frame tokens
output = einops.rearrange(motion_hidden_states,'n (f l) d -> n f l d',f=F)
return output # N,F,L,D
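# Same interface and shape contract as AudioMitionref_LearnableToken above, so
# the usage sketch there applies unchanged; the only difference is the block
# type (TransformerBlock2Condition_SimpleAdaLN, defined in .modules), which is
# presumably a lighter AdaLN-style timestep conditioning.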
# Audio
class A2MTransformer_CrossAttn_Audio(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
motion_num_token: int = 12,
motion_inchannel: int = 128,
motion_frames: int = 128,
audio_window : Optional[int] = 12,
audio_in_channels: Optional[int] = 128,
out_channels: Optional[int] = 128,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
num_layers: int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
self.out_channels = out_channels
self.motion_frames = motion_frames
self.motion_num_token = motion_num_token
self.motion_num_tokens = motion_num_token * motion_frames
# 2. Patch embedding (N,F,C,H,W) -> (B,S,D)
self.refmotion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.motion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.audio_embed = nn.Linear(audio_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 3. 1d positional embeddings
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_num_tokens))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# 4. Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
# 5. Define spatio-temporal transformer blocks
self.motion_blocks = nn.ModuleList(
[
A2MMotionSelfAttnBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.audio_blocks = nn.ModuleList(
[
A2MCrossAttnBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# 6. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=hidden_dim*2,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_hidden_states: torch.Tensor,
refmotion_hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (N,F,L,D)
refmotion_hidden_states : (N,T,L,D)
audio_hidden_states : (N,T+F,W,D)
"""
N,T,L,D = refmotion_hidden_states.shape
N,F,L,D = motion_hidden_states.shape
# 1. Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond) # (batch,time_embed_dim) (8,512)
# 2. Patch embedding
motion_hidden_states = einops.rearrange(motion_hidden_states,'n f l d -> n (f l) d')
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) # [N,FL,D]
ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states,'n t l d -> n (t l) d')
ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states) # [N,TL,D]
audio_hidden_states = self.audio_embed(audio_hidden_states) # [N,T+F,W,D]
# 3. Position embedding
ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:,:ref_motion_hidden_states.shape[1],:]
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:,:motion_hidden_states.shape[1],:]
motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# 4. Transformer blocks
for motion_block,audio_block in zip(self.motion_blocks,self.audio_blocks):
motion_hidden_states,ref_motion_hidden_states = motion_block(
motion_hidden_states,
ref_motion_hidden_states,
temb=emb,
)
motion_hidden_states,ref_motion_hidden_states = audio_block(
motion_hidden_states,
ref_motion_hidden_states,
audio_hidden_states,
temb=emb,
)
motion_hidden_states = self.norm_final(motion_hidden_states)
# 6. Final block
motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
motion_hidden_states = self.proj_out(motion_hidden_states) # [N, F*L, out_channels]
# 7. Reshape back to per-frame tokens
output = einops.rearrange(motion_hidden_states,'n (f l) d -> n f l d',f=F)
return output # N,F,L,D
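# Usage sketch (default config; T and F are illustrative, and the audio W axis
# handling is assumed from the docstring since the blocks live in .modules):
#   model = A2MTransformer_CrossAttn_Audio()
#   motion = torch.randn(2, 16, 12, 128) # (N,F,L,motion_inchannel)
#   ref = torch.randn(2, 1, 12, 128) # (N,T,L,motion_inchannel)
#   audio = torch.randn(2, 17, 12, 128) # (N,T+F,W,audio_in_channels)
#   t = torch.randint(0, 1000, (2,))
#   out = model(motion, ref, audio, t) # -> (2, 16, 12, 128)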
class A2MTransformer_CrossAttn_Audio_DoubleRef(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
motion_num_token: int = 12,
motion_inchannel: int = 128,
motion_frames: int = 128,
audio_window : Optional[int] = 12,
audio_in_channels: Optional[int] = 128,
out_channels: Optional[int] = 128,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
num_layers: int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
self.out_channels = out_channels
self.motion_frames = motion_frames
self.motion_num_token = motion_num_token
self.motion_num_tokens = motion_num_token * motion_frames
# 2. Patch embedding (N,F,C,H,W) -> (B,S,D)
self.refmotion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.motion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.audio_embed = nn.Linear(audio_in_channels,hidden_dim)
self.embedding_dropout = nn.Dropout(dropout)
# 3. 1d positional embeddings
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_num_tokens))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# 4. Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
# 5. Define spatio-temporal transformer blocks
self.motion_blocks = nn.ModuleList(
[
A2MMotionSelfAttnBlockDoubleRef(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
last_layer = (i==num_layers-1),
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for i in range(num_layers)
]
)
self.audio_blocks = nn.ModuleList(
[
A2MCrossAttnBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# 6. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=hidden_dim*2,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_hidden_states: torch.Tensor,
refmotion_hidden_states: torch.Tensor,
randomrefmotion_hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (N,F,L,D)
refmotion_hidden_states : (N,T,L,D)
randomrefmotion_hidden_states : (N,S,L,D)
audio_hidden_states : (N,T+F,W,D)
"""
N,T,L,D = refmotion_hidden_states.shape
N,F,L,D = motion_hidden_states.shape
# 1. Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond) # (batch,time_embed_dim) (8,512)
# 2. Patch embedding
motion_hidden_states = einops.rearrange(motion_hidden_states,'n f l d -> n (f l) d')
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) # [N,FL,D]
ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states,'n t l d -> n (t l) d')
ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states) # [N,TL,D]
randomref_motion_hidden_states = einops.rearrange(randomrefmotion_hidden_states,'n s l d -> n (s l) d')
randomref_motion_hidden_states = self.refmotion_patch_embed(randomref_motion_hidden_states) # [N,SL,D]
audio_hidden_states = self.audio_embed(audio_hidden_states) # [N,T+F,W,D]
# 3. Position embedding
ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:,:ref_motion_hidden_states.shape[1],:]
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:,:motion_hidden_states.shape[1],:]
randomref_motion_hidden_states = randomref_motion_hidden_states.repeat_interleave(F, dim=0) # NF,SL,D
motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# 4. Transformer blocks
for motion_block,audio_block in zip(self.motion_blocks,self.audio_blocks):
motion_hidden_states,ref_motion_hidden_states,randomref_motion_hidden_states = motion_block(
motion_hidden_states,
ref_motion_hidden_states,
randomref_motion_hidden_states,
temb=emb,
motion_token = L,
)
motion_hidden_states,ref_motion_hidden_states = audio_block(
motion_hidden_states,
ref_motion_hidden_states,
audio_hidden_states,
temb=emb,
)
motion_hidden_states = self.norm_final(motion_hidden_states)
# 6. Final block
motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
motion_hidden_states = self.proj_out(motion_hidden_states) # [N, F*L, out_channels]
# 7. Reshape back to per-frame tokens
output = einops.rearrange(motion_hidden_states,'n (f l) d -> n f l d',f=F)
return output # N,F,L,D
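# Usage sketch (as above, plus a second reference that forward repeats per
# frame along the batch axis; values are illustrative):
#   model = A2MTransformer_CrossAttn_Audio_DoubleRef()
#   motion = torch.randn(2, 16, 12, 128) # (N,F,L,motion_inchannel)
#   ref = torch.randn(2, 1, 12, 128) # (N,T,L,motion_inchannel)
#   randref = torch.randn(2, 4, 12, 128) # (N,S,L,motion_inchannel)
#   audio = torch.randn(2, 17, 12, 128) # (N,T+F,W,audio_in_channels)
#   t = torch.randint(0, 1000, (2,))
#   out = model(motion, ref, randref, audio, t) # -> (2, 16, 12, 128)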
# Audio + dwpose
class A2MTransformer_CrossAttn_Audio_Pose(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
motion_num_token: int = 12,
motion_inchannel: int = 128,
motion_frames: int = 128,
audio_window : Optional[int] = 12,
audio_in_channels: Optional[int] = 128,
out_channels: Optional[int] = 128,
pose_height : int = 32,
pose_width : int = 32,
pose_inchannel : int = 4,
pose_patch_size : int = 2,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
num_layers: int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
self.out_channels = out_channels
self.motion_frames = motion_frames
self.motion_num_token = motion_num_token
self.motion_num_tokens = motion_num_token * motion_frames
# 2. Patch embedding (N,F,C,H,W) -> (B,S,D)
self.refmotion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.motion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.audio_embed = nn.Linear(audio_in_channels,hidden_dim)
self.pose_embed = PatchEmbed(patch_size=pose_patch_size,in_channels=pose_inchannel,embed_dim=hidden_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 3. 1d positional embeddings
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange( self.motion_num_tokens))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# 4. 2d positional embeddings
iph = pose_height // pose_patch_size
ipw = pose_width // pose_patch_size
itl = iph * ipw
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False)
# 5. Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
# 6. Define spatio-temporal transformer blocks
self.motion_blocks = nn.ModuleList(
[
A2MMotionSelfAttnBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.audio_blocks = nn.ModuleList(
[
A2MCrossAttnBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.pose_blocks = nn.ModuleList(
[
A2MCrossAttnBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# 7. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=hidden_dim*2,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_hidden_states: torch.Tensor,
refmotion_hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor,
pose_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (N,F,L,D)
refmotion_hidden_states : (N,T,L,D)
audio_hidden_states : (N,T+F,W,D)
pose_hidden_states : (N,T+F,c,h,w)
"""
N,T,L,D = refmotion_hidden_states.shape
N,F,L,D = motion_hidden_states.shape
# 1. Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond) # (batch,time_embed_dim) (8,512)
# 2. Patch embedding
motion_hidden_states = einops.rearrange(motion_hidden_states,'n f l d -> n (f l) d')
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) # [N,FL,D]
ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states,'n t l d -> n (t l) d')
ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states) # [N,TL,D]
audio_hidden_states = self.audio_embed(audio_hidden_states) # [N,T+F,W,D]
pose_hidden_states = self.pose_embed(pose_hidden_states.flatten(0,1)) # [N(T+F),S,D]
# 3. Position embedding
ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:,:ref_motion_hidden_states.shape[1],:]
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:,:motion_hidden_states.shape[1],:]
pose_hidden_states = pose_hidden_states + self.pose_pos_embedding
pose_hidden_states = einops.rearrange(pose_hidden_states,'(n f) l d -> n f l d',n=N)
motion_hidden_states = self.embedding_dropout(motion_hidden_states)
# 4. Transformer blocks
for motion_block,audio_block,pose_block in zip(self.motion_blocks,self.audio_blocks,self.pose_blocks):
motion_hidden_states,ref_motion_hidden_states = motion_block(
motion_hidden_states,
ref_motion_hidden_states,
temb=emb,
)
motion_hidden_states,ref_motion_hidden_states = audio_block(
motion_hidden_states,
ref_motion_hidden_states,
audio_hidden_states,
temb=emb,
)
motion_hidden_states,ref_motion_hidden_states = pose_block(
motion_hidden_states,
ref_motion_hidden_states,
pose_hidden_states,
temb=emb,
)
motion_hidden_states = self.norm_final(motion_hidden_states)
# 6. Final block
motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
motion_hidden_states = self.proj_out(motion_hidden_states) # [N, F*L, out_channels]
# 7. Reshape back to per-frame tokens
output = einops.rearrange(motion_hidden_states,'n (f l) d -> n f l d',f=F)
return output # N,F,L,D
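# Usage sketch (default pose config: 32x32 latents, patch 2; pose must cover
# T+F frames; values are illustrative):
#   model = A2MTransformer_CrossAttn_Audio_Pose()
#   motion = torch.randn(2, 16, 12, 128) # (N,F,L,motion_inchannel)
#   ref = torch.randn(2, 1, 12, 128) # (N,T,L,motion_inchannel)
#   audio = torch.randn(2, 17, 12, 128) # (N,T+F,W,audio_in_channels)
#   pose = torch.randn(2, 17, 4, 32, 32) # (N,T+F,pose_inchannel,pose_height,pose_width)
#   t = torch.randint(0, 1000, (2,))
#   out = model(motion, ref, audio, pose, t) # -> (2, 16, 12, 128)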
# dwpose
class A2MTransformer_CrossAttn_Pose(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
motion_num_token: int = 12,
motion_inchannel: int = 128,
motion_frames: int = 128,
out_channels: Optional[int] = 128,
pose_height : int = 32,
pose_width : int = 32,
pose_inchannel : int = 4,
pose_patch_size : int = 2,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
num_layers: int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# 1. Setting
hidden_dim = num_attention_heads * attention_head_dim
self.out_channels = out_channels
self.motion_frames = motion_frames
self.motion_num_token = motion_num_token
self.motion_num_tokens = motion_num_token * motion_frames
# 2. Patch embedding (N,F,C,H,W) -> (B,S,D)
self.refmotion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.motion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
self.pose_embed = PatchEmbed(patch_size=pose_patch_size,in_channels=pose_inchannel,embed_dim=hidden_dim, bias=True)
self.embedding_dropout = nn.Dropout(dropout)
# 3. 1d positional embeddings
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_num_tokens))
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)
# 4. 2d positional embeddings
iph = pose_height // pose_patch_size
ipw = pose_width // pose_patch_size
itl = iph * ipw
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) # (iph*ipw,D)
image_pos_embedding = torch.from_numpy(image_pos_embedding) # (iph*ipw,D)
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
pos_embedding.data[:, :itl].copy_(image_pos_embedding)
self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False)
# 5. Time embeddings
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)
# 6. Define spatio-temporal transformer blocks
self.motion_blocks = nn.ModuleList(
[
A2MMotionSelfAttnBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.pose_blocks = nn.ModuleList(
[
A2MCrossAttnBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# 7. Output blocks
self.norm_out = AdaLayerNorm(
embedding_dim=time_embed_dim,
output_dim=hidden_dim*2,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
chunk_dim=1,
)
self.proj_out = nn.Linear(hidden_dim, out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
motion_hidden_states: torch.Tensor,
refmotion_hidden_states: torch.Tensor,
pose_hidden_states: torch.Tensor,
timestep: Union[int, float, torch.LongTensor], # Timesteps should be a 1d-array
timestep_cond: Optional[torch.Tensor] = None,
return_dict: bool = True,
**kwargs,
):
"""
motion_hidden_states : (N,F,L,D)
refmotion_hidden_states : (N,T,L,D)
pose_hidden_states : (N,T+F,c,h,w)
"""
N,T,L,D = refmotion_hidden_states.shape
N,F,L,D = motion_hidden_states.shape
# 1. Time embedding # Timesteps should be a 1d-array
t_emb = self.time_proj(timestep)
t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
emb = self.time_embedding(t_emb, timestep_cond) # (batch,time_embed_dim) (8,512)
# 2. Patch embedding
motion_hidden_states = einops.rearrange(motion_hidden_states,'n f l d -> n (f l) d')
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) # [N,FL,D]
ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states,'n t l d -> n (t l) d')
ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states) # [N,TL,D]
pose_hidden_states = self.pose_embed(pose_hidden_states.flatten(0,1)) # [N(T+F),S,D]
# 3. Position embedding
ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:,:ref_motion_hidden_states.shape[1],:]
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:,:motion_hidden_states.shape[1],:]
motion_hidden_states = self.embedding_dropout(motion_hidden_states)
pose_hidden_states = pose_hidden_states + self.pose_pos_embedding
pose_hidden_states = einops.rearrange(pose_hidden_states,'(n f) l d -> n f l d',n=N) # N T+F S D
# 4. Transformer blocks
for motion_block,pose_block in zip(self.motion_blocks,self.pose_blocks):
motion_hidden_states,ref_motion_hidden_states = motion_block(
motion_hidden_states,
ref_motion_hidden_states,
temb=emb,
) # N,FL,D | N,TL,D
motion_hidden_states,ref_motion_hidden_states = pose_block(
motion_hidden_states,
ref_motion_hidden_states,
pose_hidden_states, # N T+F S D
temb=emb,
)
motion_hidden_states = self.norm_final(motion_hidden_states)
# 6. Final block
motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
motion_hidden_states = self.proj_out(motion_hidden_states) # [N, F*L, out_channels]
# 7. Reshape back to per-frame tokens
output = einops.rearrange(motion_hidden_states,'n (f l) d -> n f l d',f=F)
return output # N,F,L,D
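# Usage sketch (pose-only variant; same motion/reference contract as the audio
# variants above, values are illustrative):
#   model = A2MTransformer_CrossAttn_Pose()
#   motion = torch.randn(2, 16, 12, 128) # (N,F,L,motion_inchannel)
#   ref = torch.randn(2, 1, 12, 128) # (N,T,L,motion_inchannel)
#   pose = torch.randn(2, 17, 4, 32, 32) # (N,T+F,pose_inchannel,pose_height,pose_width)
#   t = torch.randint(0, 1000, (2,))
#   out = model(motion, ref, pose, t) # -> (2, 16, 12, 128)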
# ------------- A2P ---------------
class A2PTransformer(nn.Module):
_supports_gradient_checkpointing = True
def __init__(
self,
audio_window : Optional[int] = 12,
audio_in_channels: Optional[int] = 128,
pose_height : int = 32,
pose_width : int = 32,
pose_inchannel : int = 4,
pose_patch_size : int = 4,
pose_frame : int = 17,
num_attention_heads: int = 8,
attention_head_dim: int = 64,
num_layers: int = 16,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
time_embed_dim: int = 512,
dropout: float = 0.0,
attention_bias: bool = True,
temporal_compression_ratio: int = 4,
activation_fn: str = "gelu-approximate",
timestep_activation_fn: str = "silu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
spatial_interpolation_scale: float = 1.0,
temporal_interpolation_scale: float = 1.0,
):
super().__init__()
# Setting
hidden_dim = num_attention_heads * attention_head_dim
self.out_channel = pose_inchannel
self.pose_patch_size = pose_patch_size
# Patch embedding (N,F,C,H,W) -> (B,S,D)
self.pose_embed = PatchEmbed(patch_size=pose_patch_size,in_channels=pose_inchannel,embed_dim=hidden_dim, bias=True)
self.audio_embed = nn.Linear(audio_in_channels,hidden_dim)
# pose mask token
iph = pose_height // pose_patch_size
ipw = pose_width // pose_patch_size
itl = iph * ipw
INIT_CONST = 0.02
self.pose_mask_token = nn.Parameter(torch.randn(1, itl, hidden_dim) * INIT_CONST) # 1,L,D
# 3d position embedding
spatial_pos_embedding = get_3d_sincos_pos_embed(
hidden_dim,
(iph, ipw),
pose_frame,
spatial_interpolation_scale,
temporal_interpolation_scale,
)
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) # [T*H*W, D]
pos_embedding = torch.zeros(1,*spatial_pos_embedding.shape, requires_grad=False)
pos_embedding.data.copy_(spatial_pos_embedding)
self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False)
# blocks
self.temporal_spatial_blocks = nn.ModuleList(
[
A2PTemporalSpatialBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.audio_blocks = nn.ModuleList(
[
A2PCrossAudioBlock(
dim=hidden_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
# Output blocks
self.proj_out = nn.Linear(hidden_dim, pose_patch_size * pose_patch_size * pose_inchannel)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def forward(
self,
ref_pose_hidden_states: torch.Tensor,
audio_hidden_states: torch.Tensor, # N,T+F,W,D
**kwargs,
):
"""
ref_pose_hidden_states : (N,T,C,H,W)
audio_hidden_states : (N,T+F,W,D)
"""
N,T,C,H,W = ref_pose_hidden_states.shape
N,T_F,L,D = audio_hidden_states.shape
F = T_F - T
# Audio patch embedding
audio_hidden_states = einops.rearrange(audio_hidden_states,'n f w d -> (n f) w d') # NF,W,d
audio_hidden_states = self.audio_embed(audio_hidden_states) # NF,W,D
audio_hidden_states = einops.rearrange(audio_hidden_states,'(n f) w d -> n f w d',n=N) # N,F,W,d
# pose patch embedding
ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'n t c h w -> (n t) c h w')
ref_pose_hidden_states = self.pose_embed(ref_pose_hidden_states) # NT,L,D
ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'(n t) l d -> n t l d',n=N) # N,T,L,D
pose_mask_hidden_states = self.pose_mask_token.unsqueeze(0).repeat(N,F,1,1) # N,F,L,D
# position encoding
ref_pose_hidden_states = ref_pose_hidden_states.flatten(1,2) # N,TL,D
ref_pose_hidden_states = ref_pose_hidden_states + self.pose_pos_embedding[:,:ref_pose_hidden_states.shape[1],:]
ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'n (t l) d -> n t l d',t=T) # N,T,L,D
pose_mask_hidden_states = pose_mask_hidden_states.flatten(1,2) # N,FL,D
pose_mask_hidden_states = pose_mask_hidden_states + self.pose_pos_embedding[:,:pose_mask_hidden_states.shape[1],:]
pose_mask_hidden_states = einops.rearrange(pose_mask_hidden_states,'n (f l) d -> n f l d',f=F) # N,F,L,D
pose_hidden_states = torch.cat((ref_pose_hidden_states,pose_mask_hidden_states),dim=1) # N,T+F,L,D
# Transformer blocks
for st_block,audio_block in zip(self.temporal_spatial_blocks,self.audio_blocks):
pose_hidden_states = st_block(
pose_hidden_states,
)
pose_hidden_states = audio_block(
pose_hidden_states,
audio_hidden_states,
)
pose_hidden_states = self.norm_final(pose_hidden_states)
# Final block
pose_hidden_states = self.proj_out(pose_hidden_states) # [N, T+F, L, p*p*out_channel]
# Unpatchify
p = self.pose_patch_size
output = pose_hidden_states.reshape(N, T_F, H // p, W // p, self.out_channel, p, p)
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) # N,T+F,C,H,W
return output # N,T+F,C,H,W
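# Usage sketch (default config: pose_patch_size=4, pose_frame=17, so T + F may
# not exceed pose_frame for the 3d positional buffer; no timestep is used):
#   model = A2PTransformer()
#   ref_pose = torch.randn(2, 1, 4, 32, 32) # (N,T,C,H,W)
#   audio = torch.randn(2, 17, 12, 128) # (N,T+F,W,audio_in_channels)
#   out = model(ref_pose, audio) # -> (2, 17, 4, 32, 32) == (N,T+F,C,H,W)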