|
|
import torch |
|
|
import torch.nn as nn |
|
|
from typing import Optional, Union, Dict, Any |
|
|
import einops |
|
|
|
|
|
from diffusers.utils import is_torch_version |
|
|
|
|
|
from transformers import AutoModel |
|
|
from diffusers import CogVideoXPipeline |
|
|
|
|
|
from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler |
|
|
|
|
|
from diffusers.models.attention import Attention,FeedForward |
|
|
|
|
|
from .modules import (BasicTransformerBlock, |
|
|
PatchEmbed, |
|
|
AMDTransformerBlock, |
|
|
AdaLayerNorm, |
|
|
TransformerBlock2Condition, |
|
|
TransformerBlock2Condition_SimpleAdaLN, |
|
|
A2MMotionSelfAttnBlock, |
|
|
A2MCrossAttnBlock, |
|
|
A2PTemporalSpatialBlock, |
|
|
A2PCrossAudioBlock, |
|
|
AMDTransformerMotionBlock, |
|
|
BasicDiTBlock, |
|
|
A2MMotionSelfAttnBlockDoubleRef) |
|
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed,get_2d_sincos_pos_embed,get_1d_sincos_pos_embed_from_grid |
|
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
from diffusers.utils import BaseOutput, logging |
|
|
from diffusers.models.embeddings import TimestepEmbedding, Timesteps |
|
|
from diffusers.models.modeling_utils import ModelMixin |
|
|
|
|
|
|
|
|
class MotionEncoderLearnTokenTransformer(nn.Module):
    r"""
    Motion Encoder With Learnable Token

    Encodes per-frame image latents into a small, fixed number of motion
    tokens: learnable query tokens are prepended to the patchified image
    token sequence, the joint sequence is processed by a stack of
    self-attention blocks, and the query slots are read back out and
    projected to ``motion_channel``.
    """

    def __init__(
        self,
        # --- image latent geometry ---
        img_height: int = 32,
        img_width: int = 32,
        img_inchannel: int = 4,
        img_patch_size: int = 2,
        # --- motion token settings ---
        motion_token_num:int = 12,
        motion_channel:int = 128,
        need_norm_out :bool = True,
        # --- transformer settings ---
        num_attention_heads: int = 12,
        attention_head_dim: int = 64,
        num_layers: int = 8,
        freq_shift: int = 0,
        dropout: float = 0.0,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = img_height // img_patch_size
        ipw = img_width // img_patch_size
        itl = iph * ipw
        # Number of patch tokens produced per frame.
        self.img_token_len = itl

        # Learnable motion query tokens, initialized with small gaussian noise.
        INIT_CONST = 0.02
        self.motion_token = nn.Parameter(torch.randn(1, motion_token_num, motion_channel) * INIT_CONST)
        self.motion_embed = nn.Linear(motion_channel,hidden_dim)

        # Patchify image latents into a token sequence of width hidden_dim.
        self.patch_embed = PatchEmbed(img_patch_size,img_inchannel,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed (non-persistent) 2D sin-cos positional embedding for image tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.proj_out = nn.Linear(hidden_dim, motion_channel)
        self.need_norm_out = need_norm_out
        if self.need_norm_out:
            # Affine-free LayerNorm applied to the projected motion channels.
            self.norm_out = nn.LayerNorm(motion_channel, eps=norm_eps, elementwise_affine=False)

        self.gradient_checkpointing = False

    def forward(
        self,
        img_hidden_states: torch.Tensor,
        mask_ratio = None,
    ):
        """
        Args:
            img_hidden_states: image latents of shape ``(N, T, C, H, W)``.
            mask_ratio: optional fraction of image tokens to randomly drop
                (MAE-style) before the transformer; ``None`` disables masking.

        Returns:
            Motion tokens of shape ``(N, T, motion_token_num, motion_channel)``.
        """
        N, T, C, H, W = img_hidden_states.shape

        # Broadcast the learnable queries to every (batch, frame) pair.
        motion_token = self.motion_embed(self.motion_token)
        motion_token = motion_token.repeat(N*T,1,1)

        # Fold frames into the batch dimension, then patchify each frame.
        img_hidden_states = einops.rearrange(img_hidden_states, 'n t c h w -> (n t) c h w')
        img_hidden_states = self.patch_embed(img_hidden_states)
        assert self.img_token_len == img_hidden_states.shape[1] , 'img_token_len should be equal!'

        # Add the fixed 2D positional embedding, then dropout.
        pos_embeds = self.pos_embedding[:, :self.img_token_len]
        img_hidden_states = img_hidden_states + pos_embeds
        img_hidden_states = self.embedding_dropout(img_hidden_states)

        # Optional MAE-style masking: only the kept tokens proceed.
        if mask_ratio is not None:
            img_hidden_states,_,_ = self.random_masking(img_hidden_states,mask_ratio)

        # Motion queries attend jointly with the image tokens.
        hidden_states = torch.cat([motion_token,img_hidden_states],dim=1)

        for i, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
            )

        # Read out only the motion-query slots (they were concatenated first).
        motion_token = hidden_states[:, :motion_token.shape[1],:]
        motion_token = self.norm_final(motion_token)
        motion_token = self.proj_out(motion_token)
        if self.need_norm_out:
            motion_token = self.norm_out(motion_token)

        # Unfold frames back out of the batch dimension.
        motion_token = einops.rearrange(motion_token, '(n t) l d -> n t l d',n=N)

        return motion_token

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns:
            x_masked: the kept tokens, shape ``[N, len_keep, D]``.
            mask: binary mask in original token order (0 = kept, 1 = removed).
            ids_restore: indices that undo the shuffle.
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)

        # Sort noise per sample: tokens with small noise are kept.
        ids_shuffle = torch.argsort(noise, dim=1)

        # Inverse permutation, used to map the mask back to original order.
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep shuffled positions.
        ids_keep = ids_shuffle[:, :len_keep]

        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the binary mask: 0 is keep, 1 is remove.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0

        # Unshuffle so the mask lines up with the original token order.
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
|
|
|
|
|
|
|
|
|
|
|
class MotionTransformer(nn.Module):
    """
    Motion

    Self-attention transformer over a flattened multi-frame motion-token
    sequence: embeds ``(N, F, L, D)`` tokens to the hidden width, adds a 1D
    sin-cos positional embedding over the flattened ``(F*L)`` sequence, runs
    the transformer blocks (optionally with gradient checkpointing during
    training), and projects back to the input channel count.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_token_num : int = 4,
        motion_token_channel : int = 128,
        motion_frames : int = 128,
        attention_head_dim : int = 64,
        num_attention_heads: int = 16,
        num_layers: int = 8,

        freq_shift: int = 0,
        dropout: float = 0.0,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = motion_token_channel
        # Maximum flattened sequence length: frames * tokens-per-frame.
        self.motion_token_length = motion_token_num * motion_frames

        self.embed = nn.Linear(motion_token_channel,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed (non-persistent) 1D sin-cos positional embedding over the
        # flattened (frame, token) axis.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_token_length))
        motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

        self.proj_out = nn.Linear(hidden_dim,motion_token_channel)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggled by the training framework; checked inside forward().
        self.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states: motion tokens of shape ``(N, F, L, D)``; ``F*L``
                must not exceed ``self.motion_token_length``.

        Returns:
            Tensor of shape ``(N, F, L, out_channels)``.
        """
        N,F,L,D = hidden_states.shape

        hidden_states = self.embed(hidden_states)

        # Flatten (F, L) into one sequence and add the positional embedding.
        hidden_states = hidden_states.flatten(1,2) + self.motion_pos_embedding[:,:F*L,:]

        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                # Wrap the block so torch.utils.checkpoint can re-run it with
                # positional inputs during the backward pass.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # use_reentrant=False is only available from torch 1.11.
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                )

        hidden_states = self.norm_final(hidden_states)

        hidden_states = self.proj_out(hidden_states)

        # Restore the per-frame layout.
        hidden_states = einops.rearrange(hidden_states,'n (f l) d -> n f l d',f=F)

        return hidden_states
|
|
|
|
|
|
|
|
|
|
|
class AMDReconstructTransformerModel(nn.Module):
    """
    Diffusion Transformer

    Reconstruction transformer: patchifies an image latent, concatenates it
    with embedded source/target motion-token sequences (each prefixed by a
    learnable marker token), runs joint self-attention, and decodes the image
    part of the sequence back to an image latent via unpatchify.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,

        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,

        motion_token_num:int = 12,
        motion_in_channels: Optional[int] = 128,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels

        # Input embeddings: patchify image latents, project motion tokens.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed (non-persistent) 2D sin-cos positional embedding for image tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # 1D positional embedding for [source marker + source tokens +
        # target marker + target tokens] -> length 2 + 2*motion_token_num.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
        motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)

        # Learnable marker tokens separating source and target motion segments.
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Project each image token back to a patch of output channels.
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggled by the training framework.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        **kwargs,
    ):
        """
        Args:
            motion_source_hidden_states: source motion tokens, ``(N, L, Cm)``.
            motion_target_hidden_states: target motion tokens, ``(N, L, Cm)``.
            image_hidden_states: image latent, ``(N, Ci, Hi, Wi)``.

        Returns:
            Reconstructed image latent of shape ``(N, out_channels, Hi, Wi)``.
        """
        N,Ci,Hi,Wi = image_hidden_states.shape
        N,L,Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2*L + 2

        # Embed inputs into the shared hidden width.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(image_hidden_states)

        # [source marker | source tokens | target marker | target tokens]
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1)

        # Positional embeddings, then dropout on the motion sequence.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        # BUGFIX: nn.Dropout is not in-place, so its result must be assigned;
        # previously the output was discarded and dropout was a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Joint self-attention over image + motion tokens.
        hidden_states = torch.cat([image_hidden_states,motion_hidden_states],dim=1)
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
            )

        # Decode only the image part of the sequence (it was concatenated first).
        image_hidden_states = self.norm_final(hidden_states[:,:image_hidden_states.shape[1],:])

        image_hidden_states = self.proj_out(image_hidden_states)

        # Unpatchify: (N, Hp*Wp, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
|
|
|
class AMDReconstructTransformerModelSpatial(nn.Module):
    """
    Diffusion Transformer

    Spatial/temporal variant of the reconstruction transformer: frames are
    folded into the batch dimension, and each layer first runs joint
    self-attention over image + motion tokens, then a spatial block that
    attends across frames for each image-token position.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,

        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,

        motion_token_num:int = 12,
        motion_in_channels: Optional[int] = 128,
        motion_target_num_frame : int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        # Number of target frames folded into the batch dimension.
        self.target_frame = motion_target_num_frame

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels

        # Input embeddings: patchify image latents, project motion tokens.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed (non-persistent) 2D sin-cos positional embedding for image tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # 1D positional embedding for [source marker + source tokens +
        # target marker + target tokens] -> length 2 + 2*motion_token_num.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
        motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)

        # 1D temporal embedding over the target frames for image tokens.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.target_frame))
        img_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("img_temporal_embedding",img_pos_embedding,persistent=False)

        # Learnable marker tokens separating source and target motion segments.
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        # One spatial (cross-frame) block paired with each joint block.
        self.spatial_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Project each image token back to a patch of output channels.
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggled by the training framework.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        **kwargs,
    ):
        """
        Args:
            motion_source_hidden_states: source motion tokens, ``(N, L, Cm)``
                with ``N = batch * target_frame``.
            motion_target_hidden_states: target motion tokens, ``(N, L, Cm)``.
            image_hidden_states: image latent, ``(N, Ci, Hi, Wi)``.

        Returns:
            Reconstructed image latent of shape ``(N, out_channels, Hi, Wi)``.
        """
        N,Ci,Hi,Wi = image_hidden_states.shape
        N,L,Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2*L + 2

        # N folds batch and frame together: N = n * t.
        n = N // self.target_frame
        t = self.target_frame

        # Embed inputs into the shared hidden width.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(image_hidden_states)

        # [source marker | source tokens | target marker | target tokens]
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1)

        # Positional embeddings: motion (1D), image spatial (2D), then a
        # temporal embedding added per frame position.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
        image_hidden_states = image_hidden_states + self.img_temporal_embedding[:,:t]
        image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n)
        # BUGFIX: nn.Dropout is not in-place, so its result must be assigned;
        # previously the output was discarded and dropout was a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        img_length = image_hidden_states.shape[1]
        motion_length = motion_hidden_states.shape[1]
        for block,s_block in zip(self.transformer_blocks,self.spatial_blocks):
            # Joint self-attention over image + motion tokens (per frame).
            hidden_states = torch.cat([image_hidden_states,motion_hidden_states],dim=1)
            hidden_states = block(
                hidden_states=hidden_states,
            )
            image_hidden_states = hidden_states[:,:img_length,:]
            motion_hidden_states = hidden_states[:,img_length:,:]

            # Spatial block: attend across frames for each token position.
            image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
            image_hidden_states = s_block(
                hidden_states = image_hidden_states,
            )

            image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n)

        image_hidden_states = self.norm_final(image_hidden_states)
        image_hidden_states = self.proj_out(image_hidden_states)

        # Unpatchify: (N, Hp*Wp, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
|
|
|
|
|
|
class AMDReconstructSplitTransformerModel(nn.Module):
    """
    Diffusion Transformer

    Split-embedding variant of the reconstruction transformer: the input
    image latent carries two stacked latents (z_i and z_t) along the channel
    axis, each patchified by its own embedder, concatenated with the embedded
    motion sequence, processed by joint self-attention, and decoded from the
    z_t token slots back to an image latent.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,

        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,

        motion_token_num:int = 12,
        motion_in_channels: Optional[int] = 128,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels

        # Separate patch embedders for the two stacked input latents; each
        # consumes half of the input channels.
        self.zi_image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True)
        self.zt_image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed (non-persistent) 2D sin-cos positional embedding, shared by
        # both image token streams.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # 1D positional embedding for [source marker + source tokens +
        # target marker + target tokens] -> length 2 + 2*motion_token_num.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
        motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)

        # Learnable marker tokens separating source and target motion segments.
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Project each image token back to a patch of output channels.
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggled by the training framework.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        **kwargs,
    ):
        """
        Args:
            motion_source_hidden_states: source motion tokens, ``(N, L, Cm)``.
            motion_target_hidden_states: target motion tokens, ``(N, L, Cm)``.
            image_hidden_states: stacked image latents, ``(N, 2c, Hi, Wi)``;
                the first ``c`` channels are z_i, the last ``c`` are z_t.

        Returns:
            Reconstructed image latent of shape ``(N, out_channels, Hi, Wi)``.
        """
        N,Ci,Hi,Wi = image_hidden_states.shape
        N,L,Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2*L + 2

        # Embed motion tokens into the shared hidden width.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)

        # Split the channel axis and patchify each half with its own embedder.
        zi_image_hidden_states = self.zi_image_patch_embed(image_hidden_states[:,:Ci//2,:,:])
        zt_image_hidden_states = self.zt_image_patch_embed(image_hidden_states[:,Ci//2:,:,:])

        # [source marker | source tokens | target marker | target tokens]
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1)

        # Positional embeddings, then dropout on the motion sequence.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        zi_image_hidden_states = zi_image_hidden_states + self.pos_embedding[:, :image_seq_length]
        zt_image_hidden_states = zt_image_hidden_states + self.pos_embedding[:, :image_seq_length]
        # BUGFIX: nn.Dropout is not in-place, so its result must be assigned;
        # previously the output was discarded and dropout was a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Joint self-attention; z_t tokens come first so they can be read
        # back out after the blocks.
        hidden_states = torch.cat([zt_image_hidden_states,zi_image_hidden_states,motion_hidden_states],dim=1)
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
            )

        # Decode only the z_t slots.
        image_hidden_states = self.norm_final(hidden_states[:,:zt_image_hidden_states.shape[1],:])

        image_hidden_states = self.proj_out(image_hidden_states)

        # Unpatchify: (N, Hp*Wp, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
|
|
|
|
|
|
class AMDDiffusionTransformerModel(nn.Module): |
|
|
""" |
|
|
Diffusion Transformer |
|
|
""" |
|
|
|
|
|
_supports_gradient_checkpointing = True |
|
|
def __init__( |
|
|
self, |
|
|
num_attention_heads: int = 20, |
|
|
attention_head_dim: int = 64, |
|
|
out_channels: Optional[int] = 4, |
|
|
num_layers: int = 12, |
|
|
|
|
|
image_width: int = 32, |
|
|
image_height: int = 32, |
|
|
image_patch_size: int = 2, |
|
|
image_in_channels: Optional[int] = 4, |
|
|
|
|
|
motion_token_num:int = 12, |
|
|
motion_in_channels: Optional[int] = 128, |
|
|
|
|
|
flip_sin_to_cos: bool = True, |
|
|
freq_shift: int = 0, |
|
|
time_embed_dim: int = 512, |
|
|
dropout: float = 0.0, |
|
|
attention_bias: bool = True, |
|
|
temporal_compression_ratio: int = 4, |
|
|
activation_fn: str = "gelu-approximate", |
|
|
timestep_activation_fn: str = "silu", |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
spatial_interpolation_scale: float = 1.0, |
|
|
temporal_interpolation_scale: float = 1.0, |
|
|
): |
|
|
""" |
|
|
|
|
|
Traning: |
|
|
Z(N,1,C,H,W) |
|
|
Motion(N,k,d,h,w) |
|
|
|
|
|
Inference: |
|
|
Z(N,1,C,H,W) |
|
|
Motion(N,k,d,h,w) |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
hidden_dim = num_attention_heads * attention_head_dim |
|
|
iph = image_height // image_patch_size |
|
|
ipw = image_width // image_patch_size |
|
|
itl = iph * ipw |
|
|
self.image_patch_size = image_patch_size |
|
|
self.out_channels = out_channels |
|
|
|
|
|
|
|
|
self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True) |
|
|
self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim) |
|
|
self.embedding_dropout = nn.Dropout(dropout) |
|
|
|
|
|
|
|
|
image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) |
|
|
image_pos_embedding = torch.from_numpy(image_pos_embedding) |
|
|
pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False) |
|
|
pos_embedding.data[:, :itl].copy_(image_pos_embedding) |
|
|
self.register_buffer("pos_embedding", pos_embedding, persistent=False) |
|
|
|
|
|
|
|
|
temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num)) |
|
|
motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False) |
|
|
motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding)) |
|
|
self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False) |
|
|
|
|
|
|
|
|
self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True) |
|
|
self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True) |
|
|
|
|
|
|
|
|
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift) |
|
|
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn) |
|
|
|
|
|
|
|
|
self.transformer_blocks = nn.ModuleList( |
|
|
[ |
|
|
AMDTransformerBlock( |
|
|
dim=hidden_dim, |
|
|
num_attention_heads=num_attention_heads, |
|
|
attention_head_dim=attention_head_dim, |
|
|
time_embed_dim=time_embed_dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
attention_bias=attention_bias, |
|
|
norm_elementwise_affine=norm_elementwise_affine, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
for _ in range(num_layers) |
|
|
] |
|
|
) |
|
|
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine) |
|
|
|
|
|
|
|
|
self.norm_out = AdaLayerNorm( |
|
|
embedding_dim=time_embed_dim, |
|
|
output_dim=2 * hidden_dim, |
|
|
norm_elementwise_affine=norm_elementwise_affine, |
|
|
norm_eps=norm_eps, |
|
|
chunk_dim=1, |
|
|
) |
|
|
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels) |
|
|
|
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
|
self.gradient_checkpointing = value |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
motion_source_hidden_states: torch.Tensor, |
|
|
motion_target_hidden_states: torch.Tensor, |
|
|
image_hidden_states: torch.Tensor, |
|
|
timestep: Union[int, float, torch.LongTensor], |
|
|
timestep_cond: Optional[torch.Tensor] = None, |
|
|
return_dict: bool = True, |
|
|
**kwargs, |
|
|
): |
|
|
""" |
|
|
motion_hidden_states : (b,d,h,w) |
|
|
image_hidden_states : (b,2c,H,W) |
|
|
time_step : (b,) |
|
|
|
|
|
""" |
|
|
N,Ci,Hi,Wi = image_hidden_states.shape |
|
|
N,L,Cm = motion_source_hidden_states.shape |
|
|
image_seq_length = Hi * Wi // (self.image_patch_size**2) |
|
|
motion_seq_length = 2*L + 2 |
|
|
|
|
|
|
|
|
t_emb = self.time_proj(timestep) |
|
|
t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype) |
|
|
emb = self.time_embedding(t_emb, timestep_cond) |
|
|
|
|
|
|
|
|
motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states) |
|
|
motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states) |
|
|
image_hidden_states = self.image_patch_embed(image_hidden_states) |
|
|
|
|
|
|
|
|
source_token = self.source_token.repeat(N, 1, 1) |
|
|
target_token = self.target_token.repeat(N, 1, 1) |
|
|
motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1) |
|
|
|
|
|
|
|
|
motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length] |
|
|
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length] |
|
|
self.embedding_dropout(motion_hidden_states) |
|
|
|
|
|
|
|
|
for i, block in enumerate(self.transformer_blocks): |
|
|
motion_hidden_states, image_hidden_states = block( |
|
|
hidden_states=motion_hidden_states, |
|
|
encoder_hidden_states=image_hidden_states, |
|
|
temb=emb, |
|
|
) |
|
|
|
|
|
image_hidden_states = self.norm_final(image_hidden_states) |
|
|
|
|
|
|
|
|
image_hidden_states = self.norm_out(image_hidden_states, temb=emb) |
|
|
image_hidden_states = self.proj_out(image_hidden_states) |
|
|
|
|
|
|
|
|
p = self.image_patch_size |
|
|
output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p) |
|
|
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) |
|
|
|
|
|
return output |
|
|
|
|
|
class AMDDiffusionTransformerModelDualStream(nn.Module):
    """Dual-stream diffusion transformer.

    Each layer runs two streams in lockstep: a temporal motion-only block
    (``motion_blocks``) that mixes motion tokens across the
    ``motion_target_num_frame`` frames of a clip, followed by a per-frame
    motion/image block (``transformer_blocks``) that conditions the image
    tokens on that frame's motion tokens.

    Training / inference inputs (see ``forward``):
        image latents  : (N, Ci, Hi, Wi) with N = n_clips * motion_target_num_frame
        motion latents : (N, L, Cm)
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # image branch
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # motion branch
        motion_token_num: int = 12,
        motion_in_channels: Optional[int] = 128,
        motion_target_num_frame: int = 16,
        # timestep embedding / misc
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size  # patches along height
        ipw = image_width // image_patch_size   # patches along width
        itl = iph * ipw                         # image token length
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels
        self.target_frame = motion_target_num_frame

        # Token embeddings. NOTE: ``embedding_dropout`` is registered but not
        # applied in this class's forward — kept for config compatibility.
        self.image_patch_embed = PatchEmbed(
            patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True
        )
        self.motion_patch_embed = nn.Linear(motion_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Learnable separator tokens marking the source / target motion segments.
        INIT_CONST = 0.02
        self.source_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * INIT_CONST)
        self.target_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * INIT_CONST)

        # Fixed 2D sin-cos positional table for image tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Fixed 1D sin-cos table over one frame's motion sequence:
        # [source_token] + L source + [target_token] + L target = 2L + 2 tokens.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(2 * motion_token_num + 2))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Fixed 1D sin-cos table over the flattened clip sequence:
        # target_frame frames * (2L + 2) tokens == 2 * target_frame * (L + 1).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(
            hidden_dim, torch.arange(2 * self.target_frame * (motion_token_num + 1))
        )
        motion_temporal_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_temporal_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_temporal_embedding", motion_temporal_embedding, persistent=False)

        # Timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Per-frame motion<->image blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                AMDTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Temporal motion-only blocks (operate across the frames of a clip).
        self.motion_blocks = nn.ModuleList(
            [
                AMDTransformerMotionBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Output head.
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # ``module`` is unused; signature kept for the diffusers hook.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """Denoise the image latents of a clip, dual-stream.

        Args:
            motion_source_hidden_states: ``(N, L, Cm)`` source motion tokens.
            motion_target_hidden_states: ``(N, L, Cm)`` target motion tokens.
            image_hidden_states: ``(N, Ci, Hi, Wi)`` image latent, where
                ``N = n_clips * target_frame`` (N must be divisible by
                ``target_frame``).
            timestep: diffusion timestep(s), one per batch element.
            timestep_cond: optional extra conditioning for the timestep MLP.
            return_dict: unused; kept for pipeline-API compatibility.

        Returns:
            ``(N, out_channels, Hi, Wi)`` un-patchified prediction.
        """
        N, Ci, Hi, Wi = image_hidden_states.shape
        N, L, Cm = motion_target_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)

        n = N // self.target_frame  # number of clips in the batch
        t = self.target_frame       # frames per clip
        l = L                       # motion tokens per frame

        # 1. Timestep embedding.
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # One temb per clip for the temporal motion blocks (frame 0 of each clip).
        emb_m = einops.rearrange(emb, '(n t) d -> n t d', n=n, t=t)
        emb_m = emb_m[:, 0, :]

        # 2. Token embeddings.
        image_hidden_states = self.image_patch_embed(image_hidden_states)
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)

        # 3. Per-frame motion sequence with learnable separator tokens.
        source_token = self.source_token.repeat(n * t, 1, 1)
        target_token = self.target_token.repeat(n * t, 1, 1)
        motion_hidden_states = torch.cat(
            [source_token, motion_source_hidden_states, target_token, motion_target_hidden_states], dim=1
        )

        # 4. Positional embeddings: per-frame, then across the flattened clip.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :2 * l + 2]
        motion_hidden_states = einops.rearrange(motion_hidden_states, '(n t) l d -> n (t l) d', n=n)
        motion_hidden_states = motion_hidden_states + self.motion_temporal_embedding[:, :t * (2 * l + 2)]

        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]

        # 5. Alternate temporal-motion and per-frame motion/image blocks.
        for block, m_block in zip(self.transformer_blocks, self.motion_blocks):
            # Temporal stream: motion tokens of a whole clip as one sequence.
            motion_hidden_states = m_block(
                hidden_states=motion_hidden_states,
                temb=emb_m,
            )

            # Per-frame stream: back to (n*t, 2l+2, d) for the cross block.
            motion_hidden_states = einops.rearrange(motion_hidden_states, 'n (t l) d -> (n t) l d', t=t)
            motion_hidden_states, image_hidden_states = block(
                hidden_states=motion_hidden_states,
                encoder_hidden_states=image_hidden_states,
                temb=emb,
            )
            motion_hidden_states = einops.rearrange(motion_hidden_states, '(n t) l d -> n (t l) d', t=t)

        # 6. Output head on the image stream.
        image_hidden_states = self.norm_final(image_hidden_states)
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)

        # 7. Un-patchify back to (N, out_channels, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
class AMDDiffusionTransformerModelImgSpatial(nn.Module):
    """Diffusion transformer with an extra image-temporal ("spatial") stream.

    Each layer runs a per-frame motion/image block (``transformer_blocks``)
    followed by a block that mixes each image-patch position across the
    ``motion_target_num_frame`` frames of a clip (``spatial_blocks``).

    Training / inference inputs (see ``forward``):
        image latents  : (N, Ci, Hi, Wi) with N = n_clips * motion_target_num_frame
        motion latents : (N, L, Cm)
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # image branch
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # motion branch
        motion_token_num: int = 12,
        motion_in_channels: Optional[int] = 128,
        motion_target_num_frame: int = 16,
        # timestep embedding / misc
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size  # patches along height
        ipw = image_width // image_patch_size   # patches along width
        itl = iph * ipw                         # image token length
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels
        self.target_frame = motion_target_num_frame

        # Token embeddings.
        self.image_patch_embed = PatchEmbed(
            patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True
        )
        self.motion_patch_embed = nn.Linear(motion_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional table for image tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Fixed 1D sin-cos table over one frame's motion sequence (2L + 2 tokens).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(2 + 2 * motion_token_num))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Fixed 1D sin-cos table over the frames of a clip (for the spatial blocks).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.target_frame))
        img_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("img_temporal_embedding", img_pos_embedding, persistent=False)

        # Learnable separator tokens (zero-initialized in this variant).
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)

        # Timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Per-frame motion<->image blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                AMDTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Image-temporal blocks (each patch position attends across frames).
        self.spatial_blocks = nn.ModuleList(
            [
                BasicDiTBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Output head.
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # ``module`` is unused; signature kept for the diffusers hook.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """Denoise the image latents of a clip with an image-temporal stream.

        Args:
            motion_source_hidden_states: ``(N, L, Cm)`` source motion tokens.
            motion_target_hidden_states: ``(N, L, Cm)`` target motion tokens.
            image_hidden_states: ``(N, Ci, Hi, Wi)`` image latent, where
                ``N = n_clips * target_frame`` (N must be divisible by
                ``target_frame``).
            timestep: diffusion timestep(s), one per batch element.
            timestep_cond: optional extra conditioning for the timestep MLP.
            return_dict: unused; kept for pipeline-API compatibility.

        Returns:
            ``(N, out_channels, Hi, Wi)`` un-patchified prediction.
        """
        N, Ci, Hi, Wi = image_hidden_states.shape
        N, L, Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2 * L + 2

        n = N // self.target_frame  # number of clips in the batch
        t = self.target_frame       # frames per clip

        # 1. Token embeddings.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(image_hidden_states)

        # 2. Timestep embedding.
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # One temb per (clip, patch position) for the spatial blocks:
        # frame-0 emb of each clip, repeated over the image sequence length.
        emb_s = einops.rearrange(emb, '(n t) d -> n t d', n=n, t=t)
        emb_s = emb_s[:, :1, :].repeat(1, image_hidden_states.shape[1], 1)
        emb_s = emb_s.flatten(0, 1)

        # 3. Motion sequence with learnable separator tokens.
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat(
            [source_token, motion_source_hidden_states, target_token, motion_target_hidden_states], dim=1
        )

        # 4. Positional embeddings (spatial for image tokens, temporal across frames).
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        image_hidden_states = einops.rearrange(image_hidden_states, '(n t) s d -> (n s) t d', n=n)
        image_hidden_states = image_hidden_states + self.img_temporal_embedding[:, :t]
        image_hidden_states = einops.rearrange(image_hidden_states, '(n s) t d -> (n t) s d', n=n)
        # FIX: the original discarded the dropout result, making it a no-op;
        # assign it so dropout is actually applied (identity when dropout=0.0).
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # 5. Alternate per-frame cross blocks and image-temporal blocks.
        for block, s_block in zip(self.transformer_blocks, self.spatial_blocks):
            motion_hidden_states, image_hidden_states = block(
                hidden_states=motion_hidden_states,
                encoder_hidden_states=image_hidden_states,
                temb=emb,
            )

            image_hidden_states = einops.rearrange(image_hidden_states, '(n t) s d -> (n s) t d', n=n)
            image_hidden_states = s_block(
                hidden_states=image_hidden_states,
                temb=emb_s,
            )
            image_hidden_states = einops.rearrange(image_hidden_states, '(n s) t d -> (n t) s d', n=n)

        # 6. Output head on the image stream.
        image_hidden_states = self.norm_final(image_hidden_states)
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)

        # 7. Un-patchify back to (N, out_channels, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
class AMDDiffusionTransformerModelImgSpatialDoubleRef(nn.Module):
    """Image-temporal diffusion transformer with two reference images.

    Variant of ``AMDDiffusionTransformerModelImgSpatial`` that conditions on
    two references: a "random" reference concatenated channel-wise with the
    noised latent, and a "last" reference that rides along the image-temporal
    stream as an extra (frame-0) token.

    Training / inference inputs (see ``forward``):
        image latents     : (N, 2c, Hi, Wi), last-ref + noised stacked on channels
        random reference  : (N, c, Hi, Wi)
        motion latents    : (N, L, Cm)
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # image branch
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # motion branch
        motion_token_num: int = 12,
        motion_in_channels: Optional[int] = 128,
        motion_target_num_frame: int = 16,
        # timestep embedding / misc
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size  # patches along height
        ipw = image_width // image_patch_size   # patches along width
        itl = iph * ipw                         # image token length
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels
        self.target_frame = motion_target_num_frame

        # Token embeddings. The main embed takes random-ref + noised latent
        # stacked on channels (2x in_channels); the last-ref gets its own embed.
        self.image_patch_embed = PatchEmbed(
            patch_size=image_patch_size, in_channels=image_in_channels * 2, embed_dim=hidden_dim, bias=True
        )
        self.lastref_image_patch_embed = PatchEmbed(
            patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True
        )
        self.motion_patch_embed = nn.Linear(motion_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional table for image tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Fixed 1D sin-cos table over one frame's motion sequence (2L + 2 tokens).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(2 + 2 * motion_token_num))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Fixed 1D sin-cos table over target_frame + 1 positions: slot 0 is the
        # last-ref token, slots 1..target_frame are the clip frames.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.target_frame + 1))
        img_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("img_temporal_embedding", img_pos_embedding, persistent=False)

        # Learnable separator tokens (zero-initialized in this variant).
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)

        # Timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Per-frame motion<->image blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                AMDTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Image-temporal blocks (each patch position attends across frames + last-ref).
        self.spatial_blocks = nn.ModuleList(
            [
                BasicDiTBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Output head.
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # ``module`` is unused; signature kept for the diffusers hook.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        randomref_image_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """Denoise the image latents of a clip with two reference images.

        Args:
            motion_source_hidden_states: ``(N, L, Cm)`` source motion tokens.
            motion_target_hidden_states: ``(N, L, Cm)`` target motion tokens.
            image_hidden_states: ``(N, 2c, Hi, Wi)`` — last-reference latent in
                the first ``c`` channels, noised latent in the last ``c``.
            randomref_image_hidden_states: ``(N, c, Hi, Wi)`` random-reference
                latent, concatenated channel-wise with the noised latent.
            timestep: diffusion timestep(s), one per batch element.
            timestep_cond: optional extra conditioning for the timestep MLP.
            return_dict: unused; kept for pipeline-API compatibility.

        Returns:
            ``(N, out_channels, Hi, Wi)`` un-patchified prediction.
        """
        N, Ci, Hi, Wi = image_hidden_states.shape
        N, L, Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2 * L + 2

        n = N // self.target_frame  # number of clips in the batch
        t = self.target_frame       # frames per clip

        # 1. Split the channel-stacked input into last-ref and noised latents.
        lastref_image_hidden_states = image_hidden_states[:, :Ci // 2, :, :]
        noised_image_hidden_states = image_hidden_states[:, Ci // 2:, :, :]

        # 2. Token embeddings.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(
            torch.cat([randomref_image_hidden_states, noised_image_hidden_states], dim=1)
        )
        lastref_image_hidden_states = self.lastref_image_patch_embed(lastref_image_hidden_states)

        # 3. Timestep embedding.
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # One temb per (clip, patch position) for the spatial blocks.
        emb_s = einops.rearrange(emb, '(n t) d -> n t d', n=n, t=t)
        emb_s = emb_s[:, :1, :].repeat(1, image_hidden_states.shape[1], 1)
        emb_s = emb_s.flatten(0, 1)

        # 4. Motion sequence with learnable separator tokens.
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat(
            [source_token, motion_source_hidden_states, target_token, motion_target_hidden_states], dim=1
        )

        # 5. Positional embeddings. Clip frames use temporal slots 1..t; the
        #    last-ref token uses slot 0.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]

        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        image_hidden_states = einops.rearrange(image_hidden_states, '(n t) s d -> (n s) t d', n=n)
        image_hidden_states = image_hidden_states + self.img_temporal_embedding[:, 1:t + 1]
        image_hidden_states = einops.rearrange(image_hidden_states, '(n s) t d -> (n t) s d', n=n)

        lastref_image_hidden_states = lastref_image_hidden_states + self.pos_embedding[:, :image_seq_length]
        lastref_image_hidden_states = einops.rearrange(lastref_image_hidden_states, '(n t) s d -> (n s) t d', n=n)
        lastref_image_hidden_states = lastref_image_hidden_states[:, :1, :]  # one last-ref token per clip
        lastref_image_hidden_states = lastref_image_hidden_states + self.img_temporal_embedding[:, :1]

        # FIX: the original discarded the dropout result, making it a no-op;
        # assign it so dropout is actually applied (identity when dropout=0.0).
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # 6. Alternate per-frame cross blocks and image-temporal blocks; the
        #    last-ref token is prepended for the temporal pass and split back out.
        for block, s_block in zip(self.transformer_blocks, self.spatial_blocks):
            motion_hidden_states, image_hidden_states = block(
                hidden_states=motion_hidden_states,
                encoder_hidden_states=image_hidden_states,
                temb=emb,
            )

            image_hidden_states = einops.rearrange(image_hidden_states, '(n t) s d -> (n s) t d', n=n)
            s_block_input_hidden_states = torch.cat([lastref_image_hidden_states, image_hidden_states], dim=1)

            s_block_output = s_block(
                hidden_states=s_block_input_hidden_states,
                temb=emb_s,
            )

            lastref_image_hidden_states = s_block_output[:, :1, :]
            image_hidden_states = s_block_output[:, 1:, :]

            image_hidden_states = einops.rearrange(image_hidden_states, '(n s) t d -> (n t) s d', n=n)

        # 7. Output head on the image stream.
        image_hidden_states = self.norm_final(image_hidden_states)
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)

        # 8. Un-patchify back to (N, out_channels, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
|
|
|
class AMDDiffusionTransformerModelSplitInput(nn.Module): |
|
|
""" |
|
|
Diffusion Transformer |
|
|
""" |
|
|
|
|
|
_supports_gradient_checkpointing = True |
|
|
def __init__( |
|
|
self, |
|
|
num_attention_heads: int = 30, |
|
|
attention_head_dim: int = 64, |
|
|
image_in_channels: Optional[int] = 4, |
|
|
motion_in_channels: Optional[int] = 16, |
|
|
out_channels: Optional[int] = 4, |
|
|
num_layers: int = 16, |
|
|
|
|
|
image_width: int = 64, |
|
|
image_height: int = 64, |
|
|
motion_width: int = 8, |
|
|
motion_height: int = 8, |
|
|
image_patch_size: int = 2, |
|
|
motion_patch_size: int = 1, |
|
|
motion_frames: int = 15, |
|
|
|
|
|
flip_sin_to_cos: bool = True, |
|
|
freq_shift: int = 0, |
|
|
time_embed_dim: int = 512, |
|
|
|
|
|
dropout: float = 0.0, |
|
|
attention_bias: bool = True, |
|
|
temporal_compression_ratio: int = 4, |
|
|
|
|
|
activation_fn: str = "gelu-approximate", |
|
|
timestep_activation_fn: str = "silu", |
|
|
norm_elementwise_affine: bool = True, |
|
|
norm_eps: float = 1e-5, |
|
|
spatial_interpolation_scale: float = 1.0, |
|
|
temporal_interpolation_scale: float = 1.0, |
|
|
): |
|
|
""" |
|
|
|
|
|
Traning: |
|
|
Z(N,1,C,H,W) |
|
|
Motion(N,k,d,h,w) |
|
|
|
|
|
Inference: |
|
|
Z(N,1,C,H,W) |
|
|
Motion(N,k,d,h,w) |
|
|
""" |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
hidden_dim = num_attention_heads * attention_head_dim |
|
|
iph = image_height // image_patch_size |
|
|
ipw = image_width // image_patch_size |
|
|
itl = 2*iph * ipw |
|
|
mph = motion_height // motion_patch_size |
|
|
mpw = motion_width // motion_patch_size |
|
|
mtl = mph * mpw * motion_frames |
|
|
self.max_seq_length = itl + mtl |
|
|
|
|
|
self.image_patch_size = image_patch_size |
|
|
self.motion_patch_size = motion_patch_size |
|
|
self.out_channels = out_channels |
|
|
|
|
|
|
|
|
self.zi_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True) |
|
|
self.zt_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True) |
|
|
self.motion_patch_embed = PatchEmbed(patch_size=motion_patch_size,in_channels=motion_in_channels,embed_dim=hidden_dim, bias=True) |
|
|
|
|
|
self.embedding_dropout = nn.Dropout(dropout) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
spatial_pos_embedding = get_3d_sincos_pos_embed( |
|
|
hidden_dim, |
|
|
(iph, ipw), |
|
|
2, |
|
|
spatial_interpolation_scale, |
|
|
temporal_interpolation_scale, |
|
|
) |
|
|
spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1) |
|
|
|
|
|
|
|
|
pos_embedding = torch.zeros(1,*spatial_pos_embedding.shape, requires_grad=False) |
|
|
pos_embedding.data.copy_(spatial_pos_embedding) |
|
|
self.register_buffer("pos_embedding", pos_embedding, persistent=False) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift) |
|
|
self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn) |
|
|
|
|
|
|
|
|
self.transformer_blocks = nn.ModuleList( |
|
|
[ |
|
|
AMDTransformerBlock( |
|
|
dim=hidden_dim, |
|
|
num_attention_heads=num_attention_heads, |
|
|
attention_head_dim=attention_head_dim, |
|
|
time_embed_dim=time_embed_dim, |
|
|
dropout=dropout, |
|
|
activation_fn=activation_fn, |
|
|
attention_bias=attention_bias, |
|
|
norm_elementwise_affine=norm_elementwise_affine, |
|
|
norm_eps=norm_eps, |
|
|
) |
|
|
for _ in range(num_layers) |
|
|
] |
|
|
) |
|
|
self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine) |
|
|
|
|
|
|
|
|
self.norm_out = AdaLayerNorm( |
|
|
embedding_dim=time_embed_dim, |
|
|
output_dim=2 * hidden_dim, |
|
|
norm_elementwise_affine=norm_elementwise_affine, |
|
|
norm_eps=norm_eps, |
|
|
chunk_dim=1, |
|
|
) |
|
|
self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels) |
|
|
|
|
|
self.gradient_checkpointing = False |
|
|
|
|
|
def _set_gradient_checkpointing(self, module, value=False): |
|
|
self.gradient_checkpointing = value |
|
|
|
|
|
def forward( |
|
|
self, |
|
|
motion_hidden_states: torch.Tensor, |
|
|
image_hidden_states: torch.Tensor, |
|
|
timestep: Union[int, float, torch.LongTensor], |
|
|
timestep_cond: Optional[torch.Tensor] = None, |
|
|
return_dict: bool = True, |
|
|
**kwargs, |
|
|
): |
|
|
""" |
|
|
motion_hidden_states : (b,d,h,w) |
|
|
image_hidden_states : (b,2c,H,W) |
|
|
time_step : (b,) |
|
|
|
|
|
""" |
|
|
N,Ci,Hi,Wi = image_hidden_states.shape |
|
|
N,Cm,Hm,Wm = motion_hidden_states.shape |
|
|
image_seq_length = 2 * Hi * Wi // (self.image_patch_size**2) |
|
|
motion_seq_length = Hm * Wm // (self.motion_patch_size**2) |
|
|
|
|
|
zi = image_hidden_states[:,:Ci//2] |
|
|
zt = image_hidden_states[:,Ci//2:] |
|
|
|
|
|
|
|
|
t_emb = self.time_proj(timestep) |
|
|
t_emb = t_emb.to(dtype=motion_hidden_states.dtype) |
|
|
emb = self.time_embedding(t_emb, timestep_cond) |
|
|
|
|
|
|
|
|
motion_hidden_states = self.motion_patch_embed(motion_hidden_states) |
|
|
zi_hidden_states = self.zi_patch_embed(zi) |
|
|
zt_hidden_states = self.zt_patch_embed(zt) |
|
|
image_hidden_states = torch.cat((zi_hidden_states,zt_hidden_states),dim=1) |
|
|
|
|
|
assert image_seq_length == image_hidden_states.shape[1] , f"image_seq_length : {image_seq_length} != image_hidden_states.shape[1] : {image_hidden_states.shape[1]}" |
|
|
|
|
|
|
|
|
image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length] |
|
|
self.embedding_dropout(motion_hidden_states) |
|
|
|
|
|
|
|
|
for i, block in enumerate(self.transformer_blocks): |
|
|
if self.training and self.gradient_checkpointing: |
|
|
|
|
|
def create_custom_forward(module): |
|
|
def custom_forward(*inputs): |
|
|
return module(*inputs) |
|
|
|
|
|
return custom_forward |
|
|
|
|
|
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} |
|
|
motion_hidden_states, image_hidden_states = torch.utils.checkpoint.checkpoint( |
|
|
create_custom_forward(block), |
|
|
motion_hidden_states, |
|
|
image_hidden_states, |
|
|
emb, |
|
|
**ckpt_kwargs, |
|
|
) |
|
|
else: |
|
|
motion_hidden_states, image_hidden_states = block( |
|
|
hidden_states=motion_hidden_states, |
|
|
encoder_hidden_states=image_hidden_states, |
|
|
temb=emb, |
|
|
) |
|
|
|
|
|
pre = image_hidden_states[:,image_seq_length//2:] |
|
|
pre = self.norm_final(pre) |
|
|
|
|
|
|
|
|
pre = self.norm_out(pre, temb=emb) |
|
|
pre = self.proj_out(pre) |
|
|
|
|
|
|
|
|
p = self.image_patch_size |
|
|
output = pre.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p) |
|
|
output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1) |
|
|
|
|
|
return output |
|
|
|
|
|
|
|
|
class DiffusionTransformerModel2Condition(nn.Module):
    """
    Diffusion Transformer denoising an image latent under two conditions:
    a reference-image latent and a motion latent.

    All three streams are patch-embedded to tokens, given fixed sinusoidal
    position embeddings, processed by ``TransformerBlock2Condition`` layers
    modulated by the timestep embedding, and the image stream is finally
    unpatchified back to (N, out_channels, H, W).
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        image_in_channels: Optional[int] = 4,
        motion_in_channels: Optional[int] = 16,
        out_channels: Optional[int] = 4,
        num_layers: int = 16,

        image_width: int = 32,
        image_height: int = 32,
        motion_width: int = 8,
        motion_height: int = 8,
        image_patch_size: int = 2,
        motion_patch_size: int = 1,
        motion_frames: int = 15,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Training:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim

        # Token-grid sizes after patchification.
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw

        mph = motion_height // motion_patch_size
        mpw = motion_width // motion_patch_size
        mtl = mph * mpw * motion_frames

        # Two image streams (noisy + reference) plus all motion tokens.
        self.max_seq_length = 2 * itl + mtl

        self.image_patch_size = image_patch_size
        self.motion_patch_size = motion_patch_size
        self.out_channels = out_channels

        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True)
        self.refimg_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = PatchEmbed(patch_size=motion_patch_size, in_channels=motion_in_channels, embed_dim=hidden_dim, bias=True)

        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 3D sin-cos position embedding for the two image streams
        # (temporal size 2 = noisy image + reference image).
        # FIX: the spatial grid must be (iph, ipw); the previous (iph, iph)
        # broke non-square image latents.
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            hidden_dim,
            (iph, ipw),
            2,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        pos_embedding = torch.zeros(1, *spatial_pos_embedding.shape, requires_grad=False)
        pos_embedding.data.copy_(spatial_pos_embedding)
        self.register_buffer("img_pos_embedding", pos_embedding, persistent=False)

        # Fixed 3D sin-cos position embedding for the motion stream.
        # FIX: the spatial grid is (mph, mpw), not (mph, mph).
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            hidden_dim,
            (mph, mpw),
            motion_frames,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        pos_embedding = torch.zeros(1, *spatial_pos_embedding.shape, requires_grad=False)
        pos_embedding.data.copy_(spatial_pos_embedding)
        self.register_buffer("motion_pos_embedding", pos_embedding, persistent=False)

        # Timestep conditioning MLP.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock2Condition(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN output head followed by the unpatchify projection.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Flag is read globally in `forward`; `module` is unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        refimg_hidden_states: torch.Tensor,
        motion_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Args:
            hidden_states: noisy image latent (b, c, H, W).
            refimg_hidden_states: reference image latent (b, c, H, W).
            motion_hidden_states: motion latent (b, d, h, w).
            timestep: diffusion timestep (b,).
            timestep_cond: optional extra conditioning for the timestep MLP.
            return_dict: unused; kept for interface compatibility.

        Returns:
            Denoised image latent of shape (b, out_channels, H, W).
        """
        N, Ci, Hi, Wi = hidden_states.shape
        N, Cm, Hm, Wm = motion_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = Hm * Wm // (self.motion_patch_size**2)

        # Timestep embedding (cast to activation dtype for mixed precision).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Patchify each stream into a token sequence.
        hidden_states = self.image_patch_embed(hidden_states)
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        refimg_hidden_states = self.refimg_patch_embed(refimg_hidden_states)

        # First half of img_pos_embedding belongs to the noisy image tokens,
        # second half to the reference image tokens.
        hidden_states = hidden_states + self.img_pos_embedding[:, :image_seq_length, :]
        refimg_hidden_states = refimg_hidden_states + self.img_pos_embedding[:, image_seq_length:, :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length, :]
        # FIX: the dropout output was previously computed and discarded.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        def create_custom_forward(module):
            # Thin wrapper so torch.utils.checkpoint can re-run the block.
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # FIX: the first checkpoint output was mistakenly unpacked into
                # motion_hidden_states (assigned twice), so hidden_states never
                # advanced when gradient checkpointing was enabled.
                hidden_states, refimg_hidden_states, motion_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    refimg_hidden_states,
                    motion_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, refimg_hidden_states, motion_hidden_states = block(
                    hidden_states,
                    refimg_hidden_states,
                    motion_hidden_states,
                    emb,
                )

        hidden_states = self.norm_final(hidden_states)
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # Unpatchify: (N, HW/p^2, p*p*C) -> (N, C, H, W).
        p = self.image_patch_size
        output = hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AudioMitionref_LearnableToken(nn.Module):
    """
    Audio-to-motion transformer with learnable motion tokens.

    Denoises per-frame motion tokens conditioned on reference motion tokens
    and a per-frame extra (audio) feature stream, using
    ``TransformerBlock2Condition`` layers modulated by the diffusion
    timestep embedding.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        extra_in_channels: Optional[int] = 768,
        out_channels: Optional[int] = 128,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        # Tokens per frame, and total tokens across all frames.
        self.motion_num_token = motion_num_token
        self.motion_num_tokens = motion_num_token * motion_frames

        # Linear projections into the transformer width.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.extra_embed = nn.Linear(extra_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 1D sin-cos position embedding covering the reference tokens
        # (first motion_num_token positions) followed by all frame tokens.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_token + self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Fixed 1D sin-cos position embedding for the per-frame extra stream.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_frames))
        audio_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        audio_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("audio_pos_embedding", audio_pos_embedding, persistent=False)

        # Timestep conditioning MLP.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock2Condition(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN output head followed by the channel projection.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim*2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Flag is read globally in `forward`; `module` is unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        extra_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            motion_hidden_states: per-frame motion tokens (N, F, L, D).
            refmotion_hidden_states: reference motion tokens (N, L, D).
            extra_hidden_states: per-frame extra features, e.g. audio (N, F, D).
                Position encoding for this stream, if needed, must be applied
                by the caller; the two motion streams receive their fixed
                position encodings here.
            timestep: diffusion timestep (N,).
            timestep_cond: optional extra conditioning for the timestep MLP.

        Returns:
            Denoised motion tokens of shape (N, F, L, out_channels).
        """
        assert motion_hidden_states.shape[1] == extra_hidden_states.shape[1],f"motion {motion_hidden_states.shape} ,audio{extra_hidden_states.shape}"

        # NOTE(review): L and D are re-bound by the second unpack; the code
        # assumes ref and motion share the same per-frame token count L
        # (== motion_num_token) — confirm against callers.
        N, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding (cast to activation dtype for mixed precision).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten frames into one token sequence and project each stream.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = self.refmotion_patch_embed(refmotion_hidden_states)
        extra_hidden_states = self.extra_embed(extra_hidden_states)

        # First L positions of motion_pos_embedding belong to the reference
        # tokens; the remainder to the F*L frame tokens. Adding the audio
        # position embedding assumes F == motion_frames.
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :L, :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, L:, :]
        extra_hidden_states = extra_hidden_states + self.audio_pos_embedding
        # FIX: the dropout output was previously computed and discarded.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        def create_custom_forward(module):
            # Thin wrapper so torch.utils.checkpoint can re-run the block.
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                motion_hidden_states, ref_motion_hidden_states, extra_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    motion_hidden_states,
                    ref_motion_hidden_states,
                    extra_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                motion_hidden_states, ref_motion_hidden_states, extra_hidden_states = block(
                    motion_hidden_states,
                    ref_motion_hidden_states,
                    extra_hidden_states,
                    temb=emb,
                )

        motion_hidden_states = self.norm_final(motion_hidden_states)
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the per-frame layout.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
|
|
|
|
|
|
|
|
|
class AudioMitionref_LearnableToken_SimpleAdaLN(nn.Module):
    """
    Audio-to-motion transformer with learnable motion tokens (SimpleAdaLN
    variant).

    Identical pipeline to ``AudioMitionref_LearnableToken`` but built from
    ``TransformerBlock2Condition_SimpleAdaLN`` layers.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        extra_in_channels: Optional[int] = 768,
        out_channels: Optional[int] = 128,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        # Tokens per frame, and total tokens across all frames.
        self.motion_num_token = motion_num_token
        self.motion_num_tokens = motion_num_token * motion_frames

        # Linear projections into the transformer width.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.extra_embed = nn.Linear(extra_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 1D sin-cos position embedding covering the reference tokens
        # (first motion_num_token positions) followed by all frame tokens.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_token + self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Fixed 1D sin-cos position embedding for the per-frame extra stream.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_frames))
        audio_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        audio_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("audio_pos_embedding", audio_pos_embedding, persistent=False)

        # Timestep conditioning MLP.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock2Condition_SimpleAdaLN(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN output head followed by the channel projection.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim*2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Flag is read globally in `forward`; `module` is unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        extra_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            motion_hidden_states: per-frame motion tokens (N, F, L, D).
            refmotion_hidden_states: reference motion tokens (N, L, D).
            extra_hidden_states: per-frame extra features, e.g. audio (N, F, D).
                Position encoding for this stream, if needed, must be applied
                by the caller; the two motion streams receive their fixed
                position encodings here.
            timestep: diffusion timestep (N,).
            timestep_cond: optional extra conditioning for the timestep MLP.

        Returns:
            Denoised motion tokens of shape (N, F, L, out_channels).
        """
        assert motion_hidden_states.shape[1] == extra_hidden_states.shape[1],f"motion {motion_hidden_states.shape} ,audio{extra_hidden_states.shape}"

        # NOTE(review): L and D are re-bound by the second unpack; the code
        # assumes ref and motion share the same per-frame token count L
        # (== motion_num_token) — confirm against callers.
        N, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding (cast to activation dtype for mixed precision).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten frames into one token sequence and project each stream.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = self.refmotion_patch_embed(refmotion_hidden_states)
        extra_hidden_states = self.extra_embed(extra_hidden_states)

        # First L positions of motion_pos_embedding belong to the reference
        # tokens; the remainder to the F*L frame tokens. Adding the audio
        # position embedding assumes F == motion_frames.
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :L, :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, L:, :]
        extra_hidden_states = extra_hidden_states + self.audio_pos_embedding
        # FIX: the dropout output was previously computed and discarded.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        def create_custom_forward(module):
            # Thin wrapper so torch.utils.checkpoint can re-run the block.
            def custom_forward(*inputs):
                return module(*inputs)
            return custom_forward

        for block in self.transformer_blocks:
            if self.training and self.gradient_checkpointing:
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                motion_hidden_states, ref_motion_hidden_states, extra_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    motion_hidden_states,
                    ref_motion_hidden_states,
                    extra_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                motion_hidden_states, ref_motion_hidden_states, extra_hidden_states = block(
                    motion_hidden_states,
                    ref_motion_hidden_states,
                    extra_hidden_states,
                    temb=emb,
                )

        motion_hidden_states = self.norm_final(motion_hidden_states)
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the per-frame layout.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
|
|
|
|
|
|
|
|
|
class A2MTransformer_CrossAttn_Audio(nn.Module):
    """
    Audio-to-motion transformer with interleaved self- and cross-attention.

    Each layer first runs self-attention over the (noisy + reference) motion
    tokens (``A2MMotionSelfAttnBlock``) and then cross-attends into the audio
    features (``A2MCrossAttnBlock``); both are modulated by the diffusion
    timestep embedding.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        audio_window : Optional[int] = 12,
        audio_in_channels: Optional[int] = 128,
        out_channels: Optional[int] = 128,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        # Tokens per frame, and total tokens across all frames.
        self.motion_num_token = motion_num_token
        self.motion_num_tokens = motion_num_token * motion_frames

        # NOTE(review): `audio_window` is currently unused — confirm whether
        # windowed audio attention was intended here.
        # Linear projections into the transformer width.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.audio_embed = nn.Linear(audio_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 1D sin-cos position embedding over all F*L motion tokens.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Timestep conditioning MLP.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Paired per-layer blocks: motion self-attention + audio cross-attention.
        self.motion_blocks = nn.ModuleList(
            [
                A2MMotionSelfAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.audio_blocks = nn.ModuleList(
            [
                A2MCrossAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN output head followed by the channel projection.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim*2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # NOTE(review): this flag is set but `forward` never wires it into
        # the block loop — checkpointing support would need to be added there.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Args:
            motion_hidden_states: per-frame motion tokens (N, F, L, D).
            refmotion_hidden_states: reference motion tokens (N, T, L, D).
            audio_hidden_states: audio features (N, T+F, W, D).
            timestep: diffusion timestep (N,).
            timestep_cond: optional extra conditioning for the timestep MLP.

        Returns:
            Denoised motion tokens of shape (N, F, L, out_channels).
        """
        N, T, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding (cast to activation dtype for mixed precision).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten frames into token sequences and project each stream.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states, 'n t l d -> n (t l) d')
        ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states)
        audio_hidden_states = self.audio_embed(audio_hidden_states)

        # NOTE(review): both streams take the LEADING slice of
        # motion_pos_embedding, so reference and noisy motion tokens share
        # position codes — confirm this overlap is intended.
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :ref_motion_hidden_states.shape[1], :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_hidden_states.shape[1], :]
        # FIX: the dropout output was previously computed and discarded.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        for motion_block, audio_block in zip(self.motion_blocks, self.audio_blocks):
            # Self-attention over (noisy + reference) motion tokens...
            motion_hidden_states, ref_motion_hidden_states = motion_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                temb=emb,
            )
            # ...then cross-attention into the audio features.
            motion_hidden_states, ref_motion_hidden_states = audio_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                audio_hidden_states,
                temb=emb,
            )

        motion_hidden_states = self.norm_final(motion_hidden_states)
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the per-frame layout.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
|
|
|
class A2MTransformer_CrossAttn_Audio_DoubleRef(nn.Module): |
|
|
_supports_gradient_checkpointing = True |
|
|
    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        audio_window : Optional[int] = 12,
        audio_in_channels: Optional[int] = 128,
        out_channels: Optional[int] = 128,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """Build the double-reference audio-to-motion transformer.

        Registers linear token embeddings for the motion / reference-motion /
        audio streams, a fixed 1D sin-cos position embedding over all
        motion tokens, a timestep-embedding MLP, paired per-layer
        self-attention (``A2MMotionSelfAttnBlockDoubleRef``) and audio
        cross-attention (``A2MCrossAttnBlock``) blocks, and the AdaLN
        output head.

        NOTE(review): `audio_window`, `temporal_compression_ratio` and the
        two interpolation scales are accepted but not used below — presumably
        kept for config compatibility; confirm.
        """
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        # Tokens per frame, and total token count across all frames.
        self.motion_num_token = motion_num_token
        self.motion_num_tokens = motion_num_token * motion_frames

        # Linear projections into the transformer width. Note the reference
        # embedding is shared by both reference streams in `forward`.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel,hidden_dim)
        self.audio_embed = nn.Linear(audio_in_channels,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed (non-persistent) 1D sin-cos position embedding over all
        # F*L motion token positions.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)

        # Timestep conditioning MLP.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Motion self-attention blocks; only the last layer is flagged via
        # `last_layer`.
        self.motion_blocks = nn.ModuleList(
            [
                A2MMotionSelfAttnBlockDoubleRef(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    last_layer = (i==num_layers-1),
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for i in range(num_layers)
            ]
        )

        # Audio cross-attention blocks, one per motion block.
        self.audio_blocks = nn.ModuleList(
            [
                A2MCrossAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN output head followed by the channel projection.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim*2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False
|
|
|
|
|
    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle gradient checkpointing for this model (called by the
        # diffusers enable/disable machinery; `module` is unused because the
        # flag is checked globally in `forward`).
        self.gradient_checkpointing = value
|
|
|
|
|
def forward(
    self,
    motion_hidden_states: torch.Tensor,
    refmotion_hidden_states: torch.Tensor,
    randomrefmotion_hidden_states: torch.Tensor,
    audio_hidden_states: torch.Tensor,
    timestep: Union[int, float, torch.LongTensor],
    timestep_cond: Optional[torch.Tensor] = None,
    return_dict: bool = True,
    **kwargs,
):
    """Denoise motion latents conditioned on reference motion and audio.

    Args:
        motion_hidden_states: noisy motion latents, shape (N, F, L, D).
        refmotion_hidden_states: reference motion latents, shape (N, T, L, D).
        randomrefmotion_hidden_states: extra random-reference latents,
            shape (N, S, L, D).
        audio_hidden_states: audio features, shape (N, T+F, W, D).
        timestep: diffusion timestep(s) for the time embedding.
        timestep_cond: optional extra conditioning for the time embedding.
        return_dict: unused; kept for interface compatibility.

    Returns:
        Tensor of shape (N, F, L, out_channels).
    """
    N, T, L, D = refmotion_hidden_states.shape
    N, F, L, D = motion_hidden_states.shape

    # Timestep embedding; time_proj may emit fp32, so cast to the
    # activations' dtype before the MLP.
    t_emb = self.time_proj(timestep)
    t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
    emb = self.time_embedding(t_emb, timestep_cond)

    # Flatten (frame, token) axes and project every stream to hidden_dim.
    motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
    motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
    ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states, 'n t l d -> n (t l) d')
    ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states)
    randomref_motion_hidden_states = einops.rearrange(randomrefmotion_hidden_states, 'n s l d -> n (s l) d')
    # Random references share the reference-motion projection weights.
    randomref_motion_hidden_states = self.refmotion_patch_embed(randomref_motion_hidden_states)
    audio_hidden_states = self.audio_embed(audio_hidden_states)

    # Add the fixed sin-cos positional table, sliced to the actual length.
    ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :ref_motion_hidden_states.shape[1], :]
    motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_hidden_states.shape[1], :]
    # One random-reference copy per motion frame along the batch axis.
    randomref_motion_hidden_states = randomref_motion_hidden_states.repeat_interleave(F, dim=0)
    # FIX: nn.Dropout is out-of-place; the original call discarded its
    # result, so embedding dropout was silently never applied.
    motion_hidden_states = self.embedding_dropout(motion_hidden_states)

    # Per layer: joint self-attention over motion/ref/random-ref tokens,
    # then cross-attention from motion tokens into audio features.
    for motion_block, audio_block in zip(self.motion_blocks, self.audio_blocks):
        motion_hidden_states, ref_motion_hidden_states, randomref_motion_hidden_states = motion_block(
            motion_hidden_states,
            ref_motion_hidden_states,
            randomref_motion_hidden_states,
            temb=emb,
            motion_token=L,
        )

        motion_hidden_states, ref_motion_hidden_states = audio_block(
            motion_hidden_states,
            ref_motion_hidden_states,
            audio_hidden_states,
            temb=emb,
        )

    motion_hidden_states = self.norm_final(motion_hidden_states)

    # Timestep-modulated output norm, then projection to out_channels.
    motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
    motion_hidden_states = self.proj_out(motion_hidden_states)

    # Restore the (frame, token) split.
    output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

    return output
|
|
|
|
|
|
|
|
|
|
|
class A2MTransformer_CrossAttn_Audio_Pose(nn.Module):
    """Audio-to-motion transformer conditioned on audio and pose.

    Per layer, motion tokens self-attend jointly with reference-motion
    tokens (``A2MMotionSelfAttnBlock``), then cross-attend to audio
    features and to pose features (two ``A2MCrossAttnBlock`` stacks),
    all modulated by the diffusion-timestep embedding.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        audio_window: Optional[int] = 12,
        audio_in_channels: Optional[int] = 128,
        out_channels: Optional[int] = 128,

        pose_height: int = 32,
        pose_width: int = 32,
        pose_inchannel: int = 4,
        pose_patch_size: int = 2,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        self.motion_num_token = motion_num_token
        # Full motion sequence length = frames * tokens-per-frame.
        self.motion_num_tokens = motion_num_token * motion_frames

        # Input projections for each conditioning stream.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.audio_embed = nn.Linear(audio_in_channels, hidden_dim)
        self.pose_embed = PatchEmbed(patch_size=pose_patch_size, in_channels=pose_inchannel, embed_dim=hidden_dim, bias=True)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed (non-persistent) 1D sin-cos positional table over the
        # flattened motion sequence; sliced to the actual length in forward().
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Fixed 2D sin-cos positional table over the pose patch grid.
        iph = pose_height // pose_patch_size
        ipw = pose_width // pose_patch_size
        itl = iph * ipw
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False)

        # Diffusion-timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Joint motion/reference self-attention stack.
        self.motion_blocks = nn.ModuleList(
            [
                A2MMotionSelfAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Audio cross-attention stack (one block per layer).
        self.audio_blocks = nn.ModuleList(
            [
                A2MCrossAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Pose cross-attention stack (one block per layer).
        self.pose_blocks = nn.ModuleList(
            [
                A2MCrossAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Timestep-modulated output norm + projection back to latent channels.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # ``module`` is accepted for API compatibility but unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        pose_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """Denoise motion latents conditioned on reference motion, audio, pose.

        motion_hidden_states : (N,F,L,D)
        refmotion_hidden_states : (N,T,L,D)  # 4-D, unpacked below
        audio_hidden_states : (N,F+1,W,D)
        pose_hidden_states : (N,F+1,c,h,w)

        Returns a tensor of shape (N, F, L, out_channels).
        ``return_dict`` is unused; kept for interface compatibility.
        """
        N, T, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding; time_proj may emit fp32, so match dtype.
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten (frame, token) axes and project each stream to hidden_dim.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states, 'n f l d -> n (f l) d')
        ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states)

        audio_hidden_states = self.audio_embed(audio_hidden_states)
        # Patchify all pose frames at once: (N*(F+1), l, hidden_dim).
        pose_hidden_states = self.pose_embed(pose_hidden_states.flatten(0, 1))

        # Add fixed positional tables (sliced to actual sequence lengths).
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :ref_motion_hidden_states.shape[1], :]

        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_hidden_states.shape[1], :]

        pose_hidden_states = pose_hidden_states + self.pose_pos_embedding
        pose_hidden_states = einops.rearrange(pose_hidden_states, '(n f) l d -> n f l d', n=N)
        # FIX: nn.Dropout is out-of-place; the original call discarded its
        # result, so embedding dropout was silently never applied.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Per layer: joint self-attn, then audio cross-attn, then pose cross-attn.
        for motion_block, audio_block, pose_block in zip(self.motion_blocks, self.audio_blocks, self.pose_blocks):

            motion_hidden_states, ref_motion_hidden_states = motion_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                temb=emb,
            )

            motion_hidden_states, ref_motion_hidden_states = audio_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                audio_hidden_states,
                temb=emb,
            )

            motion_hidden_states, ref_motion_hidden_states = pose_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                pose_hidden_states,
                temb=emb,
            )

        motion_hidden_states = self.norm_final(motion_hidden_states)

        # Timestep-modulated output norm, then projection to out_channels.
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the (frame, token) split.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
|
|
|
|
|
|
|
|
|
class A2MTransformer_CrossAttn_Pose(nn.Module):
    """Pose-conditioned audio-free motion transformer.

    Per layer, motion tokens self-attend jointly with reference-motion
    tokens (``A2MMotionSelfAttnBlock``) and then cross-attend to pose
    features (``A2MCrossAttnBlock``), modulated by the diffusion-timestep
    embedding.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        out_channels: Optional[int] = 128,

        pose_height: int = 32,
        pose_width: int = 32,
        pose_inchannel: int = 4,
        pose_patch_size: int = 2,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        self.motion_num_token = motion_num_token
        # Full motion sequence length = frames * tokens-per-frame.
        self.motion_num_tokens = motion_num_token * motion_frames

        # Input projections.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.pose_embed = PatchEmbed(patch_size=pose_patch_size, in_channels=pose_inchannel, embed_dim=hidden_dim, bias=True)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed (non-persistent) 1D sin-cos positional table over the
        # flattened motion sequence; sliced to the actual length in forward().
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Fixed 2D sin-cos positional table over the pose patch grid.
        iph = pose_height // pose_patch_size
        ipw = pose_width // pose_patch_size
        itl = iph * ipw
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False)

        # Diffusion-timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Joint motion/reference self-attention stack.
        self.motion_blocks = nn.ModuleList(
            [
                A2MMotionSelfAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Pose cross-attention stack (one block per layer).
        self.pose_blocks = nn.ModuleList(
            [
                A2MCrossAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Timestep-modulated output norm + projection back to latent channels.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # ``module`` is accepted for API compatibility but unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        pose_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """Denoise motion latents conditioned on reference motion and pose.

        motion_hidden_states : (N,F,L,D)
        refmotion_hidden_states : (N,T,L,D)
        pose_hidden_states : (N,T+F,c,h,w)

        Returns a tensor of shape (N, F, L, out_channels).
        ``return_dict`` is unused; kept for interface compatibility.
        """
        N, T, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding; time_proj may emit fp32, so match dtype.
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten (frame, token) axes and project each stream to hidden_dim.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states, 'n t l d -> n (t l) d')
        ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states)
        # Patchify all pose frames at once: (N*(T+F), l, hidden_dim).
        pose_hidden_states = self.pose_embed(pose_hidden_states.flatten(0, 1))

        # Add fixed positional tables (sliced to actual sequence lengths).
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :ref_motion_hidden_states.shape[1], :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_hidden_states.shape[1], :]
        # FIX: nn.Dropout is out-of-place; the original call discarded its
        # result, so embedding dropout was silently never applied.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)
        pose_hidden_states = pose_hidden_states + self.pose_pos_embedding
        pose_hidden_states = einops.rearrange(pose_hidden_states, '(n f) l d -> n f l d', n=N)

        # Per layer: joint self-attention, then pose cross-attention.
        for motion_block, pose_block in zip(self.motion_blocks, self.pose_blocks):

            motion_hidden_states, ref_motion_hidden_states = motion_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                temb=emb,
            )

            motion_hidden_states, ref_motion_hidden_states = pose_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                pose_hidden_states,
                temb=emb,
            )

        motion_hidden_states = self.norm_final(motion_hidden_states)

        # Timestep-modulated output norm, then projection to out_channels.
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the (frame, token) split.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class A2PTransformer(nn.Module):
    """Audio-to-pose transformer.

    Given T reference pose frames and audio features covering T+F frames,
    predicts pose latents for all T+F frames; the F unknown frames start
    from a learnable mask token. Each layer alternates temporal/spatial
    self-attention with cross-attention into the audio features.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,

        audio_window : Optional[int] = 12,
        audio_in_channels: Optional[int] = 128,

        pose_height : int = 32,
        pose_width : int = 32,
        pose_inchannel : int = 4,
        pose_patch_size : int = 4,
        pose_frame : int = 17,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        # Output channels equal the input pose latent channels (autoencoding).
        self.out_channel = pose_inchannel
        self.pose_patch_size = pose_patch_size

        # Input projections: pose frames are patchified, audio is a linear map.
        self.pose_embed = PatchEmbed(patch_size=pose_patch_size,in_channels=pose_inchannel,embed_dim=hidden_dim, bias=True)
        self.audio_embed = nn.Linear(audio_in_channels,hidden_dim)

        # Patch-grid dimensions: iph x ipw patches, itl tokens per frame.
        iph = pose_height // pose_patch_size
        ipw = pose_width // pose_patch_size
        itl = iph * ipw
        INIT_CONST = 0.02
        # Learnable per-patch token standing in for the frames to be predicted.
        self.pose_mask_token = nn.Parameter(torch.randn(1, itl, hidden_dim) * INIT_CONST)

        # Fixed (non-persistent) 3D sin-cos positional table over
        # (pose_frame, iph, ipw), flattened to (pose_frame * itl, hidden_dim).
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            hidden_dim,
            (iph, ipw),
            pose_frame,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        pos_embedding = torch.zeros(1,*spatial_pos_embedding.shape, requires_grad=False)
        pos_embedding.data.copy_(spatial_pos_embedding)
        self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False)

        # Temporal/spatial self-attention stack.
        self.temporal_spatial_blocks = nn.ModuleList(
            [
                A2PTemporalSpatialBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Audio cross-attention stack (one block per layer).
        self.audio_blocks = nn.ModuleList(
            [
                A2PCrossAudioBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Projection from hidden_dim back to one patch worth of latent values.
        self.proj_out = nn.Linear(hidden_dim, pose_patch_size * pose_patch_size * pose_inchannel)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # ``module`` is accepted for API compatibility but unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        ref_pose_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        **kwargs,
    ):
        """
        ref_pose_hidden_states : (N,T,C,H,W)
        audio_hidden_states : (N,T+F,W,D)

        Returns predicted pose latents of shape (N, T+F, C, H, W).
        """

        N,T,C,H,W = ref_pose_hidden_states.shape
        N,T_F,L,D = audio_hidden_states.shape
        # Number of frames to predict (audio covers T reference + F new frames).
        F = T_F - T

        # Project audio features to hidden_dim (frame axis folded into batch).
        audio_hidden_states = einops.rearrange(audio_hidden_states,'n f w d -> (n f) w d')
        audio_hidden_states = self.audio_embed(audio_hidden_states)
        audio_hidden_states = einops.rearrange(audio_hidden_states,'(n f) w d -> n f w d',n=N)

        # Patchify reference pose frames: (N, T, l, hidden_dim).
        ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'n t c h w -> (n t) c h w')
        ref_pose_hidden_states = self.pose_embed(ref_pose_hidden_states)
        ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'(n t) l d -> n t l d',n=N)
        # Broadcast the learnable mask token over batch and the F new frames.
        pose_mask_hidden_states = self.pose_mask_token.unsqueeze(0).repeat(N,F,1,1)

        # Add 3D positional embeddings, sliced to each segment's length.
        ref_pose_hidden_states = ref_pose_hidden_states.flatten(1,2)
        ref_pose_hidden_states = ref_pose_hidden_states + self.pose_pos_embedding[:,:ref_pose_hidden_states.shape[1],:]
        ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'n (t l) d -> n t l d',t=T)

        # NOTE(review): mask tokens take positions [0, F*l) of the same table,
        # overlapping the reference frames' positions [0, T*l) rather than
        # continuing at offset T*l — confirm this reuse is intended.
        pose_mask_hidden_states = pose_mask_hidden_states.flatten(1,2)
        pose_mask_hidden_states = pose_mask_hidden_states + self.pose_pos_embedding[:,:pose_mask_hidden_states.shape[1],:]
        pose_mask_hidden_states = einops.rearrange(pose_mask_hidden_states,'n (f l) d -> n f l d',f=F)
        # Full sequence: T reference frames followed by F masked frames.
        pose_hidden_states = torch.cat((ref_pose_hidden_states,pose_mask_hidden_states),dim=1)

        # Per layer: temporal/spatial self-attention, then audio cross-attention.
        for st_block,audio_block in zip(self.temporal_spatial_blocks,self.audio_blocks):

            pose_hidden_states = st_block(
                pose_hidden_states,
            )

            pose_hidden_states = audio_block(
                pose_hidden_states,
                audio_hidden_states,
            )

        pose_hidden_states = self.norm_final(pose_hidden_states)

        pose_hidden_states = self.proj_out(pose_hidden_states)

        # Un-patchify: (N, T+F, l, p*p*C) -> (N, T+F, C, H, W).
        p = self.pose_patch_size
        output = pose_hidden_states.reshape(N, T_F, H // p, W // p, self.out_channel, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        return output
|
|
|