| import torch |
| import torch.nn as nn |
| from typing import Optional, Union, Dict, Any |
| import einops |
|
|
| from diffusers.utils import is_torch_version |
|
|
| from transformers import AutoModel |
| from diffusers import CogVideoXPipeline |
|
|
| from diffusers import CogVideoXPipeline, CogVideoXDDIMScheduler |
|
|
| from diffusers.models.attention import Attention,FeedForward |
|
|
| from .modules import (BasicTransformerBlock, |
| PatchEmbed, |
| AMDTransformerBlock, |
| AdaLayerNorm, |
| TransformerBlock2Condition, |
| TransformerBlock2Condition_SimpleAdaLN, |
| A2MMotionSelfAttnBlock, |
| A2MCrossAttnBlock, |
| A2PTemporalSpatialBlock, |
| A2PCrossAudioBlock, |
| AMDTransformerMotionBlock, |
| BasicDiTBlock, |
| A2MMotionSelfAttnBlockDoubleRef) |
| from diffusers.models.embeddings import TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed,get_2d_sincos_pos_embed,get_1d_sincos_pos_embed_from_grid |
| from diffusers.configuration_utils import ConfigMixin, register_to_config |
| from diffusers.utils import BaseOutput, logging |
| from diffusers.models.embeddings import TimestepEmbedding, Timesteps |
| from diffusers.models.modeling_utils import ModelMixin |
|
|
| |
class MotionEncoderLearnTokenTransformer(nn.Module):
    r"""
    Motion Encoder With Learnable Token.

    Encodes a sequence of image latents (N, T, C, H, W) into per-frame motion
    tokens (N, T, motion_token_num, motion_channel): learnable query tokens are
    prepended to the patchified image tokens of each frame, the joint sequence
    is refined by self-attention blocks, and only the query positions are read
    out and projected back to ``motion_channel``.
    """

    def __init__(
        self,
        # --- image latent geometry ---
        img_height: int = 32,
        img_width: int = 32,
        img_inchannel: int = 4,
        img_patch_size: int = 2,
        # --- learnable motion-token configuration ---
        motion_token_num: int = 12,
        motion_channel: int = 128,
        need_norm_out: bool = True,
        # --- transformer configuration ---
        num_attention_heads: int = 12,
        attention_head_dim: int = 64,
        num_layers: int = 8,
        freq_shift: int = 0,  # NOTE(review): unused in this module — kept for config compatibility
        dropout: float = 0.0,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,  # NOTE(review): unused in this module
        temporal_interpolation_scale: float = 1.0,  # NOTE(review): unused in this module
    ):
        super().__init__()

        # Token bookkeeping: number of image patches per frame.
        hidden_dim = num_attention_heads * attention_head_dim
        iph = img_height // img_patch_size
        ipw = img_width // img_patch_size
        itl = iph * ipw
        self.img_token_len = itl

        # Learnable motion query tokens, small-gaussian initialized.
        INIT_CONST = 0.02
        self.motion_token = nn.Parameter(torch.randn(1, motion_token_num, motion_channel) * INIT_CONST)
        self.motion_embed = nn.Linear(motion_channel, hidden_dim)

        # Patchify image latents into a token sequence.
        self.patch_embed = PatchEmbed(img_patch_size, img_inchannel, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional embedding for the image tokens
        # (registered non-persistent: recomputed, not stored in checkpoints).
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Stack of plain self-attention transformer blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Output head: norm, project to motion_channel, optional affine-free norm.
        self.norm_final = nn.LayerNorm(hidden_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)
        self.proj_out = nn.Linear(hidden_dim, motion_channel)
        self.need_norm_out = need_norm_out
        if self.need_norm_out:
            self.norm_out = nn.LayerNorm(motion_channel, eps=norm_eps, elementwise_affine=False)

        self.gradient_checkpointing = False

    def forward(
        self,
        img_hidden_states: torch.Tensor,
        mask_ratio: Optional[float] = None,
    ):
        """
        Args:
            img_hidden_states: image latents of shape (N, T, C, H, W).
            mask_ratio: if given, fraction of image tokens randomly dropped
                before the transformer (MAE-style masking).

        Returns:
            Motion tokens of shape (N, T, motion_token_num, motion_channel).
        """
        N, T, C, H, W = img_hidden_states.shape

        # Project learnable motion queries to hidden dim and tile per frame.
        motion_token = self.motion_embed(self.motion_token)
        motion_token = motion_token.repeat(N*T, 1, 1)

        # Patchify each frame independently: (N,T,C,H,W) -> (N*T, S, hidden).
        img_hidden_states = einops.rearrange(img_hidden_states, 'n t c h w -> (n t) c h w')
        img_hidden_states = self.patch_embed(img_hidden_states)
        assert self.img_token_len == img_hidden_states.shape[1] , 'img_token_len should be equal!'

        # Add 2D positional embedding, then dropout.
        pos_embeds = self.pos_embedding[:, :self.img_token_len]
        img_hidden_states = img_hidden_states + pos_embeds
        img_hidden_states = self.embedding_dropout(img_hidden_states)

        # Optionally drop a random subset of image tokens.
        if mask_ratio is not None:
            img_hidden_states, _, _ = self.random_masking(img_hidden_states, mask_ratio)

        # Prepend motion queries to the (possibly masked) image tokens.
        hidden_states = torch.cat([motion_token, img_hidden_states], dim=1)

        for i, block in enumerate(self.transformer_blocks):
            hidden_states = block(
                hidden_states=hidden_states,
            )

        # Read out only the motion-query positions and project back.
        motion_token = hidden_states[:, :motion_token.shape[1], :]
        motion_token = self.norm_final(motion_token)
        motion_token = self.proj_out(motion_token)
        if self.need_norm_out:
            motion_token = self.norm_out(motion_token)

        # Restore the frame axis: (N*T, L, D) -> (N, T, L, D).
        motion_token = einops.rearrange(motion_token, '(n t) l d -> n t l d', n=N)

        return motion_token

    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence

        Returns:
            x_masked: (N, len_keep, D) kept tokens.
            mask: (N, L) binary mask in original order, 0 = keep, 1 = removed.
            ids_restore: (N, L) indices that undo the shuffle.
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - mask_ratio))

        # Uniform noise decides which tokens survive.
        noise = torch.rand(N, L, device=x.device)

        # Sort noise per sample: small values are kept, large are removed.
        ids_shuffle = torch.argsort(noise, dim=1)
        # Inverse permutation, used to map the mask back to original order.
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # Keep the first len_keep tokens of the shuffled order.
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # Build the binary mask in shuffled order, then un-shuffle it.
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore
|
|
|
|
| |
class MotionTransformer(nn.Module):
    """
    Motion refinement transformer.

    Takes motion tokens of shape (N, F, L, D), flattens the (frame, token)
    axes into one sequence with a 1D sin-cos positional embedding, refines it
    with self-attention blocks (optionally under gradient checkpointing), and
    projects back to the input channel dimension.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_token_num: int = 4,
        motion_token_channel: int = 128,
        motion_frames: int = 128,
        attention_head_dim: int = 64,
        num_attention_heads: int = 16,
        num_layers: int = 8,
        freq_shift: int = 0,  # NOTE(review): unused in this module — kept for config compatibility
        dropout: float = 0.0,
        attention_bias: bool = True,
        activation_fn: str = "gelu-approximate",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,  # NOTE(review): unused in this module
        temporal_interpolation_scale: float = 1.0,  # NOTE(review): unused in this module
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = motion_token_channel
        # Maximum flattened sequence length: tokens per frame * frames.
        self.motion_token_length = motion_token_num * motion_frames

        self.embed = nn.Linear(motion_token_channel, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 1D sin-cos positional embedding over the flattened (F, L) axis
        # (non-persistent: not stored in checkpoints).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_token_length))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, eps=norm_eps, elementwise_affine=norm_elementwise_affine)

        # Project back to the motion channel dimension.
        self.proj_out = nn.Linear(hidden_dim, motion_token_channel)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook used by the diffusers gradient-checkpointing machinery.
        self.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
    ):
        """
        Args:
            hidden_states: motion tokens of shape (N, F, L, D).

        Returns:
            Refined motion tokens of shape (N, F, L, motion_token_channel).
        """
        N, F, L, D = hidden_states.shape

        hidden_states = self.embed(hidden_states)

        # Flatten (F, L) into one sequence and add positional embedding.
        hidden_states = hidden_states.flatten(1, 2) + self.motion_pos_embedding[:, :F*L, :]

        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:
                # Recompute activations in backward to save memory.
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # Non-reentrant checkpointing where supported (torch >= 1.11).
                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                )

        hidden_states = self.norm_final(hidden_states)

        hidden_states = self.proj_out(hidden_states)

        # Restore the frame axis: (N, F*L, D) -> (N, F, L, D).
        hidden_states = einops.rearrange(hidden_states, 'n (f l) d -> n f l d', f=F)

        return hidden_states
|
|
|
|
| |
class AMDReconstructTransformerModel(nn.Module):
    """
    Reconstruction transformer.

    Reconstructs a target image latent from an image latent plus source/target
    motion tokens. The image latent is patchified, motion tokens are linearly
    embedded, and the joint sequence
    [image tokens | source marker, source motion, target marker, target motion]
    runs through self-attention blocks; only the image-token positions are
    decoded back to patch pixels.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # --- image latent geometry ---
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # --- motion token configuration ---
        motion_token_num: int = 12,
        motion_in_channels: Optional[int] = 128,
        # --- kept for config compatibility; unused in this module ---
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels

        # Token embedders.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional embedding for image tokens (non-persistent).
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Fixed 1D sin-cos positional embedding for the motion sequence:
        # one marker + motion_token_num tokens for source and for target.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(2 + 2 * motion_token_num))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Learnable marker tokens that delimit the source / target motion segments.
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Decode image tokens back to (patch_size^2 * out_channels) pixels.
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook used by the diffusers gradient-checkpointing machinery.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        **kwargs,
    ):
        """
        Args:
            motion_source_hidden_states: (N, L, Cm) source motion tokens.
            motion_target_hidden_states: (N, L, Cm) target motion tokens.
            image_hidden_states: (N, Ci, Hi, Wi) image latent.

        Returns:
            (N, out_channels, Hi, Wi) reconstructed image latent.
        """
        N, Ci, Hi, Wi = image_hidden_states.shape
        N, L, Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2 * L + 2

        # Embed motion tokens and patchify the image latent.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(image_hidden_states)

        # Assemble the motion sequence: [src marker, src tokens, tgt marker, tgt tokens].
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token, motion_source_hidden_states, target_token, motion_target_hidden_states], dim=1)

        # Positional embeddings.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        # BUGFIX: nn.Dropout is not in-place — the result must be assigned.
        # Previously the return value was discarded, so dropout was a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Joint self-attention over [image tokens | motion tokens].
        hidden_states = torch.cat([image_hidden_states, motion_hidden_states], dim=1)
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
            )

        # Keep only the image-token positions and decode them.
        image_hidden_states = self.norm_final(hidden_states[:, :image_hidden_states.shape[1], :])
        image_hidden_states = self.proj_out(image_hidden_states)

        # Un-patchify: (N, S, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
class AMDReconstructTransformerModelSpatial(nn.Module):
    """
    Reconstruction transformer with an additional temporal stream.

    Like ``AMDReconstructTransformerModel``, but each joint image/motion block
    is followed by a second block that attends across the frame axis: image
    tokens at the same spatial position over ``motion_target_num_frame``
    frames are grouped into one sequence, given a 1D temporal embedding, and
    refined by ``spatial_blocks``.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # --- image latent geometry ---
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # --- motion token configuration ---
        motion_token_num: int = 12,
        motion_in_channels: Optional[int] = 128,
        motion_target_num_frame: int = 16,
        # --- kept for config compatibility; unused in this module ---
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        # Number of target frames per sample; the batch dim of forward() is
        # assumed to be n_samples * target_frame.
        self.target_frame = motion_target_num_frame

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels

        # Token embedders.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional embedding for image tokens (non-persistent).
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Fixed 1D sin-cos positional embedding for the motion sequence:
        # one marker + motion_token_num tokens for source and for target.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(2 + 2 * motion_token_num))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Fixed 1D sin-cos embedding over the frame axis for the temporal stream.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.target_frame))
        img_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("img_temporal_embedding", img_pos_embedding, persistent=False)

        # Learnable marker tokens that delimit the source / target motion segments.
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)

        # Joint image/motion blocks, interleaved with temporal blocks (one pair per layer).
        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.spatial_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Decode image tokens back to (patch_size^2 * out_channels) pixels.
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook used by the diffusers gradient-checkpointing machinery.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        **kwargs,
    ):
        """
        Args:
            motion_source_hidden_states: (N, L, Cm) source motion tokens,
                where N = n_samples * target_frame.
            motion_target_hidden_states: (N, L, Cm) target motion tokens.
            image_hidden_states: (N, Ci, Hi, Wi) image latent.

        Returns:
            (N, out_channels, Hi, Wi) reconstructed image latent.
        """
        N, Ci, Hi, Wi = image_hidden_states.shape
        N, L, Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2 * L + 2

        # Batch is flattened (n samples, t frames); undo/redo per stream below.
        n = N // self.target_frame
        t = self.target_frame

        # Embed motion tokens and patchify the image latent.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(image_hidden_states)

        # Assemble the motion sequence: [src marker, src tokens, tgt marker, tgt tokens].
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token, motion_source_hidden_states, target_token, motion_target_hidden_states], dim=1)

        # Positional embeddings: motion (1D), image spatial (2D) and temporal (1D).
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        image_hidden_states = einops.rearrange(image_hidden_states, '(n t) s d -> (n s) t d', n=n)
        image_hidden_states = image_hidden_states + self.img_temporal_embedding[:, :t]
        image_hidden_states = einops.rearrange(image_hidden_states, '(n s) t d -> (n t) s d', n=n)
        # BUGFIX: nn.Dropout is not in-place — the result must be assigned.
        # Previously the return value was discarded, so dropout was a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Alternate joint [image|motion] attention with per-position temporal attention.
        img_length = image_hidden_states.shape[1]
        for block, s_block in zip(self.transformer_blocks, self.spatial_blocks):
            hidden_states = torch.cat([image_hidden_states, motion_hidden_states], dim=1)
            hidden_states = block(
                hidden_states=hidden_states,
            )
            image_hidden_states = hidden_states[:, :img_length, :]
            motion_hidden_states = hidden_states[:, img_length:, :]

            # Temporal stream: attend over frames at each spatial position.
            image_hidden_states = einops.rearrange(image_hidden_states, '(n t) s d -> (n s) t d', n=n)
            image_hidden_states = s_block(
                hidden_states=image_hidden_states,
            )
            image_hidden_states = einops.rearrange(image_hidden_states, '(n s) t d -> (n t) s d', n=n)

        # Decode the image tokens.
        image_hidden_states = self.norm_final(image_hidden_states)
        image_hidden_states = self.proj_out(image_hidden_states)

        # Un-patchify: (N, S, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
|
class AMDReconstructSplitTransformerModel(nn.Module):
    """
    Reconstruction transformer with split image embedders.

    Variant of ``AMDReconstructTransformerModel`` where the input image latent
    carries two concatenated halves along the channel dim (reference ``zi`` and
    target ``zt``); each half gets its own patch embedder, and the joint
    sequence [zt tokens | zi tokens | motion tokens] runs through the
    self-attention stack. Only the zt positions are decoded to the output.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # --- image latent geometry (image_in_channels counts BOTH halves) ---
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # --- motion token configuration ---
        motion_token_num: int = 12,
        motion_in_channels: Optional[int] = 128,
        # --- kept for config compatibility; unused in this module ---
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels

        # Separate patch embedders for the two channel halves of the input.
        self.zi_image_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels//2, embed_dim=hidden_dim, bias=True)
        self.zt_image_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels//2, embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional embedding shared by both image streams (non-persistent).
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Fixed 1D sin-cos positional embedding for the motion sequence:
        # one marker + motion_token_num tokens for source and for target.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(2 + 2 * motion_token_num))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Learnable marker tokens that delimit the source / target motion segments.
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Decode image tokens back to (patch_size^2 * out_channels) pixels.
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook used by the diffusers gradient-checkpointing machinery.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        **kwargs,
    ):
        """
        Args:
            motion_source_hidden_states: (N, L, Cm) source motion tokens.
            motion_target_hidden_states: (N, L, Cm) target motion tokens.
            image_hidden_states: (N, Ci, Hi, Wi) image latent; channels are the
                concatenation [zi (first Ci//2) | zt (last Ci//2)].

        Returns:
            (N, out_channels, Hi, Wi) reconstructed image latent.
        """
        N, Ci, Hi, Wi = image_hidden_states.shape
        N, L, Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2 * L + 2

        # Embed motion tokens.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)

        # Split the channel dim and patchify each half with its own embedder.
        zi_image_hidden_states = self.zi_image_patch_embed(image_hidden_states[:, :Ci//2, :, :])
        zt_image_hidden_states = self.zt_image_patch_embed(image_hidden_states[:, Ci//2:, :, :])

        # Assemble the motion sequence: [src marker, src tokens, tgt marker, tgt tokens].
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token, motion_source_hidden_states, target_token, motion_target_hidden_states], dim=1)

        # Positional embeddings.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        zi_image_hidden_states = zi_image_hidden_states + self.pos_embedding[:, :image_seq_length]
        zt_image_hidden_states = zt_image_hidden_states + self.pos_embedding[:, :image_seq_length]
        # BUGFIX: nn.Dropout is not in-place — the result must be assigned.
        # Previously the return value was discarded, so dropout was a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Joint self-attention over [zt tokens | zi tokens | motion tokens].
        hidden_states = torch.cat([zt_image_hidden_states, zi_image_hidden_states, motion_hidden_states], dim=1)
        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states=hidden_states,
            )

        # Keep only the zt positions (leading segment) and decode them.
        image_hidden_states = self.norm_final(hidden_states[:, :zt_image_hidden_states.shape[1], :])
        image_hidden_states = self.proj_out(image_hidden_states)

        # Un-patchify: (N, S, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
|
class AMDDiffusionTransformerModel(nn.Module):
    """
    Diffusion transformer.

    Denoises an image latent conditioned on source/target motion tokens and a
    diffusion timestep. Motion tokens and image patches are embedded, the
    timestep is turned into a conditioning vector via sinusoidal projection +
    MLP, and ``AMDTransformerBlock`` layers jointly update both streams; the
    image stream is finally modulated (AdaLayerNorm) and decoded to patches.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # --- image latent geometry ---
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # --- motion token configuration ---
        motion_token_num: int = 12,
        motion_in_channels: Optional[int] = 128,
        # --- timestep embedding ---
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels

        # Token embedders.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional embedding for image tokens (non-persistent).
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Fixed 1D sin-cos positional embedding for the motion sequence:
        # one marker + motion_token_num tokens for source and for target.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(2 + 2 * motion_token_num))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Learnable marker tokens that delimit the source / target motion segments.
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim), requires_grad=True)

        # Timestep conditioning: sinusoidal projection followed by an MLP.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Time-conditioned dual-stream transformer blocks.
        self.transformer_blocks = nn.ModuleList(
            [
                AMDTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Output head: time-modulated norm, then decode to patch pixels.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook used by the diffusers gradient-checkpointing machinery.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Args:
            motion_source_hidden_states: (N, L, Cm) source motion tokens.
            motion_target_hidden_states: (N, L, Cm) target motion tokens.
            image_hidden_states: (N, Ci, Hi, Wi) noisy image latent.
            timestep: diffusion timestep(s), broadcast over the batch.
            timestep_cond: optional extra conditioning for the timestep MLP.
            return_dict: accepted for pipeline compatibility; the raw tensor
                is always returned.

        Returns:
            (N, out_channels, Hi, Wi) denoised image latent prediction.
        """
        N, Ci, Hi, Wi = image_hidden_states.shape
        N, L, Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2 * L + 2

        # Timestep -> conditioning vector (match the activation dtype).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Embed motion tokens and patchify the image latent.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(image_hidden_states)

        # Assemble the motion sequence: [src marker, src tokens, tgt marker, tgt tokens].
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token, motion_source_hidden_states, target_token, motion_target_hidden_states], dim=1)

        # Positional embeddings.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        # BUGFIX: nn.Dropout is not in-place — the result must be assigned.
        # Previously the return value was discarded, so dropout was a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Dual-stream blocks update motion and image tokens jointly, conditioned on time.
        for block in self.transformer_blocks:
            motion_hidden_states, image_hidden_states = block(
                hidden_states=motion_hidden_states,
                encoder_hidden_states=image_hidden_states,
                temb=emb,
            )

        image_hidden_states = self.norm_final(image_hidden_states)

        # Time-modulated output norm, then decode to patch pixels.
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)

        # Un-patchify: (N, S, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
class AMDDiffusionTransformerModelDualStream(nn.Module):
    """Dual-stream AMD diffusion transformer.

    Denoises per-frame image latents conditioned on motion tokens. Two stacks
    are interleaved at every layer:

    * ``motion_blocks`` — self-attention over the full motion sequence of a
      clip (all ``t`` frames flattened into one sequence per video);
    * ``transformer_blocks`` — joint blocks over per-frame motion tokens and
      per-frame image patch tokens.
    """

    _supports_gradient_checkpointing = True
    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # image latent geometry
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # motion token geometry
        motion_token_num:int = 12,
        motion_in_channels: Optional[int] = 128,
        motion_target_num_frame : int = 16,
        # timestep / transformer hyper-parameters
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        # Derived sizes: transformer width and the image patch grid.
        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw  # number of image patch tokens per frame
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels
        self.target_frame = motion_target_num_frame

        # Token embeddings: conv patchify for image latents, linear
        # projection for motion tokens.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)
        # NOTE(review): embedding_dropout is defined but never called in this
        # class's forward — confirm whether dropout on the embeddings was
        # intended here as in the sibling variants.

        # Learnable separator tokens prepended to the source / target motion
        # sub-sequences.
        INIT_CONST = 0.02
        self.source_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * INIT_CONST)
        self.target_token = nn.Parameter(torch.randn(1, 1, hidden_dim) * INIT_CONST)

        # Fixed 2D sin-cos positional embedding for image patch tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Per-frame 1D positional embedding over the 2*motion_token_num + 2
        # motion positions (source token + source seq + target token + target seq).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2*motion_token_num+2))
        motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)

        # Temporal embedding over the clip-level flattened motion sequence:
        # 2 * target_frame * (motion_token_num + 1) == t * (2*l + 2) positions.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2*self.target_frame * (motion_token_num+1)))
        motion_temporal_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_temporal_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_temporal_embedding",motion_temporal_embedding,persistent=False)

        # Sinusoidal timestep projection followed by an MLP embedding.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Joint motion/image blocks (per-frame sequences).
        self.transformer_blocks = nn.ModuleList(
            [
                AMDTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        # Motion-only blocks (clip-level sequences), one per joint block.
        self.motion_blocks = nn.ModuleList(
            [
                AMDTransformerMotionBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Output head: AdaLN modulated by the timestep embedding, then a
        # linear projection back to patch pixels.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # diffusers-style hook; the `module` argument is intentionally unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        motion_hidden_states : (nt,l,d)   # n clips * t frames, l motion tokens
        image_hidden_states : (b,2c,H,W)  # b == n*t
        time_step : (b,)
        """
        N,Ci,Hi,Wi = image_hidden_states.shape
        N,L,Cm = motion_target_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        # NOTE(review): unused; sibling variants use 2*L + 2 here (the two
        # separator tokens added below) — kept as-is.
        motion_seq_length = 2*L

        # Clip/frame factorization: the batch axis N is (n clips) x (t frames).
        n = N // self.target_frame
        t = self.target_frame
        l = L
        d = Cm

        # Timestep embedding, one per frame: (N, time_embed_dim).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Clip-level conditioning for the motion stream: take frame 0's
        # embedding per clip (assumes all frames of a clip share one
        # timestep — TODO confirm against the caller).
        emb_m = einops.rearrange(emb,'(n t) d -> n t d',n=n,t=t)
        emb_m = emb_m[:,0,:]

        # Embed image latents into patch tokens: (N, s, hidden_dim).
        image_hidden_states = self.image_patch_embed(image_hidden_states)

        # Project motion tokens to the transformer width.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        # Per-frame motion sequence: [SRC] + source tokens + [TGT] + target tokens.
        source_token = self.source_token.repeat(n*t,1,1)
        target_token = self.target_token.repeat(n*t,1,1)
        motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1)
        # Within-frame positions, then clip-level temporal positions after
        # flattening frames into one long sequence per clip.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :2*l+2]

        motion_hidden_states = einops.rearrange(motion_hidden_states,'(n t) l d -> n (t l) d',n=n)
        motion_hidden_states = motion_hidden_states + self.motion_temporal_embedding[:,:t*(2*l+2)]

        # Spatial positions for the image patch tokens.
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]

        # Interleave clip-level motion blocks with per-frame joint blocks.
        for block,m_block in zip(self.transformer_blocks,self.motion_blocks):

            # Motion self-attention over the whole clip: (n, t*(2l+2), d).
            motion_hidden_states = m_block(
                hidden_states = motion_hidden_states,
                temb = emb_m,
            )

            # Back to per-frame sequences for the joint block.
            motion_hidden_states = einops.rearrange(motion_hidden_states,'n (t l) d -> (n t) l d',t=t)

            # Joint motion/image block (per frame).
            motion_hidden_states, image_hidden_states = block(
                hidden_states=motion_hidden_states,
                encoder_hidden_states=image_hidden_states,
                temb=emb,
            )

            # Re-flatten frames for the next motion block.
            motion_hidden_states = einops.rearrange(motion_hidden_states,'(n t) l d -> n (t l) d',t=t)

        image_hidden_states = self.norm_final(image_hidden_states)

        # Output head: AdaLN + projection to patch pixels.
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)

        # Un-patchify: (N, s, p*p*C) -> (N, C, Hi/p * p, Wi/p * p).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
class AMDDiffusionTransformerModelImgSpatial(nn.Module):
    """AMD diffusion transformer with an extra image temporal stream.

    Each layer runs a joint motion/image block over per-frame token
    sequences, then a ``BasicDiTBlock`` that attends across the ``t`` frames
    of a clip independently for every spatial location.

    Fix vs. original: the embedding dropout result was previously discarded
    (``self.embedding_dropout(x)`` without assignment), making the dropout a
    silent no-op; the result is now assigned. Behavior is unchanged at the
    default ``dropout=0.0`` and in eval mode.
    """

    _supports_gradient_checkpointing = True
    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # image latent geometry
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # motion token geometry
        motion_token_num:int = 12,
        motion_in_channels: Optional[int] = 128,
        motion_target_num_frame : int = 16,
        # timestep / transformer hyper-parameters
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        # Derived sizes: transformer width and the image patch grid.
        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw  # image patch tokens per frame
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels
        self.target_frame = motion_target_num_frame

        # Token embeddings: conv patchify for image latents, linear
        # projection for motion tokens.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional embedding for image patch tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Per-frame 1D positional embedding over the 2 + 2*motion_token_num
        # motion positions (source token + source seq + target token + target seq).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
        motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)

        # Temporal embedding over the t frame positions used by the spatial
        # (cross-frame) stream.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.target_frame))
        img_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("img_temporal_embedding",img_pos_embedding,persistent=False)

        # Learnable separator tokens (zero-initialized in this variant).
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)

        # Sinusoidal timestep projection followed by an MLP embedding.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Joint motion/image blocks (per-frame sequences).
        self.transformer_blocks = nn.ModuleList(
            [
                AMDTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        # Cross-frame (temporal) blocks over image tokens, one per joint block.
        self.spatial_blocks = nn.ModuleList(
            [
                BasicDiTBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Output head: AdaLN modulated by the timestep embedding, then a
        # linear projection back to patch pixels.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # diffusers-style hook; the `module` argument is intentionally unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        motion_hidden_states : (b,d,h,w)
        image_hidden_states : (b,2c,H,W)  # b == n clips * t frames
        time_step : (b,)
        """
        N,Ci,Hi,Wi = image_hidden_states.shape
        N,L,Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2*L + 2  # two separator tokens + both motion seqs

        # Clip/frame factorization: the batch axis N is (n clips) x (t frames).
        n = N // self.target_frame
        t = self.target_frame

        # Project tokens into the transformer width.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(image_hidden_states)

        # Timestep embedding, one per frame: (N, time_embed_dim).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Conditioning for the cross-frame stream: frame 0's embedding per
        # clip, broadcast over spatial locations -> (n*s, time_embed_dim).
        emb_s = einops.rearrange(emb,'(n t) d -> n t d',n=n,t=t)
        emb_s = emb_s[:,:1,:].repeat(1,image_hidden_states.shape[1],1)
        emb_s = emb_s.flatten(0,1)

        # Per-frame motion sequence: [SRC] + source tokens + [TGT] + target tokens.
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1)

        # Positional embeddings: within-frame for motion tokens; 2D spatial
        # plus per-frame temporal for image tokens.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
        image_hidden_states = image_hidden_states + self.img_temporal_embedding[:,:t]
        image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n)
        # Fix: assign the dropout result (nn.Dropout is not in-place; the
        # original discarded it, so dropout was never applied).
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Interleave joint motion/image blocks with cross-frame blocks.
        for block,s_block in zip(self.transformer_blocks,self.spatial_blocks):
            motion_hidden_states, image_hidden_states = block(
                hidden_states=motion_hidden_states,
                encoder_hidden_states=image_hidden_states,
                temb=emb,
            )

            # Attend across frames per spatial location: (n*s, t, d).
            image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
            image_hidden_states = s_block(
                hidden_states = image_hidden_states,
                temb = emb_s,
            )

            image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n)

        image_hidden_states = self.norm_final(image_hidden_states)

        # Output head: AdaLN + projection to patch pixels.
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)

        # Un-patchify: (N, s, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
class AMDDiffusionTransformerModelImgSpatialDoubleRef(nn.Module):
    """AMD diffusion transformer with a temporal stream and two references.

    Variant of ``AMDDiffusionTransformerModelImgSpatial`` that conditions on
    two reference latents: a "random" reference concatenated channel-wise
    with the noised latent before patchify, and a "last" reference that is
    carried as an extra frame (position 0) through the cross-frame blocks.

    Fix vs. original: the embedding dropout result was previously discarded
    (``self.embedding_dropout(x)`` without assignment), making the dropout a
    silent no-op; the result is now assigned. Behavior is unchanged at the
    default ``dropout=0.0`` and in eval mode.
    """

    _supports_gradient_checkpointing = True
    def __init__(
        self,
        num_attention_heads: int = 20,
        attention_head_dim: int = 64,
        out_channels: Optional[int] = 4,
        num_layers: int = 12,
        # image latent geometry
        image_width: int = 32,
        image_height: int = 32,
        image_patch_size: int = 2,
        image_in_channels: Optional[int] = 4,
        # motion token geometry
        motion_token_num:int = 12,
        motion_in_channels: Optional[int] = 128,
        motion_target_num_frame : int = 16,
        # timestep / transformer hyper-parameters
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,
        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        # Derived sizes: transformer width and the image patch grid.
        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = iph * ipw  # image patch tokens per frame
        self.image_patch_size = image_patch_size
        self.out_channels = out_channels
        self.target_frame = motion_target_num_frame

        # Patch embeddings: the main one takes [random-ref, noised] latents
        # concatenated channel-wise (2x channels); a separate one embeds the
        # last-reference latent alone.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels*2,embed_dim=hidden_dim, bias=True)
        self.lastref_image_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels,embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = nn.Linear(motion_in_channels,hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 2D sin-cos positional embedding for image patch tokens.
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Per-frame 1D positional embedding over the 2 + 2*motion_token_num
        # motion positions (source token + source seq + target token + target seq).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(2+2*motion_token_num))
        motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False)

        # Temporal embedding over target_frame + 1 positions: index 0 is the
        # last-reference "frame", indices 1..t are the target frames.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange(self.target_frame + 1))
        img_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False)
        img_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("img_temporal_embedding",img_pos_embedding,persistent=False)

        # Learnable separator tokens (zero-initialized in this variant).
        self.source_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)
        self.target_token = nn.Parameter(torch.zeros(1, 1, hidden_dim),requires_grad=True)

        # Sinusoidal timestep projection followed by an MLP embedding.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Joint motion/image blocks (per-frame sequences).
        self.transformer_blocks = nn.ModuleList(
            [
                AMDTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        # Cross-frame (temporal) blocks over image tokens, one per joint block.
        self.spatial_blocks = nn.ModuleList(
            [
                BasicDiTBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Output head: AdaLN modulated by the timestep embedding, then a
        # linear projection back to patch pixels.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # diffusers-style hook; the `module` argument is intentionally unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_source_hidden_states: torch.Tensor,
        motion_target_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        randomref_image_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        motion_hidden_states : (b,d,h,w)
        image_hidden_states : (b,2c,H,W)  # channel-concat of [last-ref, noised]
        randomref_image_hidden_states : (b,c,H,W)
        time_step : (b,)
        """
        N,Ci,Hi,Wi = image_hidden_states.shape
        N,L,Cm = motion_source_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = 2*L + 2  # two separator tokens + both motion seqs

        # Clip/frame factorization: the batch axis N is (n clips) x (t frames).
        n = N // self.target_frame
        t = self.target_frame

        # Channel split: first half is the last-reference latent, second half
        # is the noised target latent.
        lastref_image_hidden_states = image_hidden_states[:,:Ci//2,:,:]
        noised_image_hidden_states = image_hidden_states[:,Ci//2:,:,:]

        # Project tokens into the transformer width. The main stream embeds
        # [random-ref, noised] concatenated channel-wise.
        motion_source_hidden_states = self.motion_patch_embed(motion_source_hidden_states)
        motion_target_hidden_states = self.motion_patch_embed(motion_target_hidden_states)
        image_hidden_states = self.image_patch_embed(torch.cat([randomref_image_hidden_states,noised_image_hidden_states],dim=1))
        lastref_image_hidden_states = self.lastref_image_patch_embed(lastref_image_hidden_states)

        # Timestep embedding, one per frame: (N, time_embed_dim).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_source_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Conditioning for the cross-frame stream: frame 0's embedding per
        # clip, broadcast over spatial locations -> (n*s, time_embed_dim).
        emb_s = einops.rearrange(emb,'(n t) d -> n t d',n=n,t=t)
        emb_s = emb_s[:,:1,:].repeat(1,image_hidden_states.shape[1],1)
        emb_s = emb_s.flatten(0,1)

        # Per-frame motion sequence: [SRC] + source tokens + [TGT] + target tokens.
        source_token = self.source_token.repeat(N, 1, 1)
        target_token = self.target_token.repeat(N, 1, 1)
        motion_hidden_states = torch.cat([source_token,motion_source_hidden_states,target_token,motion_target_hidden_states],dim=1)

        # Within-frame motion positions.
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length]

        # Image tokens: 2D spatial positions, then temporal positions 1..t
        # (slot 0 is reserved for the last-reference frame).
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
        image_hidden_states = image_hidden_states + self.img_temporal_embedding[:,1:t+1]
        image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n)

        # Last-reference tokens: keep one copy per clip (frame index 0) and
        # give it temporal position 0.
        lastref_image_hidden_states = lastref_image_hidden_states + self.pos_embedding[:, :image_seq_length]
        lastref_image_hidden_states = einops.rearrange(lastref_image_hidden_states,'(n t) s d -> (n s) t d',n=n)
        lastref_image_hidden_states = lastref_image_hidden_states[:,:1,:]
        lastref_image_hidden_states = lastref_image_hidden_states + self.img_temporal_embedding[:,:1]
        # Fix: assign the dropout result (nn.Dropout is not in-place; the
        # original discarded it, so dropout was never applied).
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Interleave joint motion/image blocks with cross-frame blocks that
        # also see the last-reference "frame" at position 0.
        for block,s_block in zip(self.transformer_blocks,self.spatial_blocks):

            motion_hidden_states, image_hidden_states = block(
                hidden_states=motion_hidden_states,
                encoder_hidden_states=image_hidden_states,
                temb=emb,
            )

            # Attend across frames per spatial location, with the
            # last-reference token prepended: (n*s, 1+t, d).
            image_hidden_states = einops.rearrange(image_hidden_states,'(n t) s d -> (n s) t d',n=n)
            s_block_input_hidden_states = torch.cat([lastref_image_hidden_states,image_hidden_states],dim=1)

            s_block_output = s_block(
                hidden_states = s_block_input_hidden_states,
                temb = emb_s,
            )
            # Split the reference slot back off and keep it for next layer.
            lastref_image_hidden_states = s_block_output[:,:1,:]
            image_hidden_states = s_block_output[:,1:,:]

            image_hidden_states = einops.rearrange(image_hidden_states,'(n s) t d -> (n t) s d',n=n)

        image_hidden_states = self.norm_final(image_hidden_states)

        # Output head: AdaLN + projection to patch pixels.
        image_hidden_states = self.norm_out(image_hidden_states, temb=emb)
        image_hidden_states = self.proj_out(image_hidden_states)

        # Un-patchify: (N, s, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = image_hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
class AMDDiffusionTransformerModelSplitInput(nn.Module):
    """AMD diffusion transformer with separately-embedded reference/noised latents.

    The 2c-channel image input is split into a reference half (``zi``) and a
    noised half (``zt``); each half gets its own ``PatchEmbed`` and the two
    token sequences are concatenated along the sequence axis. Only the
    ``zt`` half of the final sequence is decoded back to pixels.

    Fix vs. original: the embedding dropout result was previously discarded
    (``self.embedding_dropout(x)`` without assignment), making the dropout a
    silent no-op; the result is now assigned. Behavior is unchanged at the
    default ``dropout=0.0`` and in eval mode.
    """

    _supports_gradient_checkpointing = True
    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        image_in_channels: Optional[int] = 4,
        motion_in_channels: Optional[int] = 16,
        out_channels: Optional[int] = 4,
        num_layers: int = 16,
        # image / motion latent geometry
        image_width: int = 64,
        image_height: int = 64,
        motion_width: int = 8,
        motion_height: int = 8,
        image_patch_size: int = 2,
        motion_patch_size: int = 1,
        motion_frames: int = 15,
        # timestep embedding
        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,
        # transformer hyper-parameters
        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        """
        Traning:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)

        Inference:
            Z(N,1,C,H,W)
            Motion(N,k,d,h,w)
        """
        super().__init__()

        # Derived sizes. itl counts BOTH halves (zi + zt) of the image
        # token sequence; mtl counts motion tokens over all motion frames.
        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size
        ipw = image_width // image_patch_size
        itl = 2*iph * ipw
        mph = motion_height // motion_patch_size
        mpw = motion_width // motion_patch_size
        mtl = mph * mpw * motion_frames
        self.max_seq_length = itl + mtl

        self.image_patch_size = image_patch_size
        self.motion_patch_size = motion_patch_size
        self.out_channels = out_channels

        # Separate patch embeddings for the reference (zi) and noised (zt)
        # halves of the channel-concatenated input, plus motion patchify.
        self.zi_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True)
        self.zt_patch_embed = PatchEmbed(patch_size=image_patch_size,in_channels=image_in_channels//2,embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = PatchEmbed(patch_size=motion_patch_size,in_channels=motion_in_channels,embed_dim=hidden_dim, bias=True)

        self.embedding_dropout = nn.Dropout(dropout)

        # Fixed 3D sin-cos positional embedding over 2 "frames" (zi, zt) of
        # the image patch grid, flattened to length itl. Motion tokens get
        # no positional embedding in this variant.
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            hidden_dim,
            (iph, ipw),
            2,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)

        pos_embedding = torch.zeros(1,*spatial_pos_embedding.shape, requires_grad=False)
        pos_embedding.data.copy_(spatial_pos_embedding)
        self.register_buffer("pos_embedding", pos_embedding, persistent=False)

        # Sinusoidal timestep projection followed by an MLP embedding.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Joint motion/image transformer stack.
        self.transformer_blocks = nn.ModuleList(
            [
                AMDTransformerBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Output head: AdaLN modulated by the timestep embedding, then a
        # linear projection back to patch pixels.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # diffusers-style hook; the `module` argument is intentionally unused.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        image_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        motion_hidden_states : (b,d,h,w)
        image_hidden_states : (b,2c,H,W)  # channel-concat of [zi, zt]
        time_step : (b,)
        """
        N,Ci,Hi,Wi = image_hidden_states.shape
        N,Cm,Hm,Wm = motion_hidden_states.shape
        image_seq_length = 2 * Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = Hm * Wm // (self.motion_patch_size**2)

        # Channel split: reference half (zi) and noised half (zt).
        zi = image_hidden_states[:,:Ci//2]
        zt = image_hidden_states[:,Ci//2:]

        # Timestep embedding: (N, time_embed_dim).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Patchify each stream; image sequence is [zi tokens, zt tokens].
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        zi_hidden_states = self.zi_patch_embed(zi)
        zt_hidden_states = self.zt_patch_embed(zt)
        image_hidden_states = torch.cat((zi_hidden_states,zt_hidden_states),dim=1)

        assert image_seq_length == image_hidden_states.shape[1] , f"image_seq_length : {image_seq_length} != image_hidden_states.shape[1] : {image_hidden_states.shape[1]}"

        # Positional embedding over both halves of the image sequence.
        image_hidden_states = image_hidden_states + self.pos_embedding[:, :image_seq_length]
        # Fix: assign the dropout result (nn.Dropout is not in-place; the
        # original discarded it, so dropout was never applied).
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Transformer stack, optionally under activation checkpointing.
        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                motion_hidden_states, image_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    motion_hidden_states,
                    image_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                motion_hidden_states, image_hidden_states = block(
                    hidden_states=motion_hidden_states,
                    encoder_hidden_states=image_hidden_states,
                    temb=emb,
                )

        # Decode only the zt (noised) half of the image sequence.
        pre = image_hidden_states[:,image_seq_length//2:]
        pre = self.norm_final(pre)

        # Output head: AdaLN + projection to patch pixels.
        pre = self.norm_out(pre, temb=emb)
        pre = self.proj_out(pre)

        # Un-patchify: (N, s, p*p*C) -> (N, C, Hi, Wi).
        p = self.image_patch_size
        output = pre.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
class DiffusionTransformerModel2Condition(nn.Module):
    """
    Diffusion Transformer that denoises an image latent conditioned on a
    reference-image latent and a motion latent.

    Training / Inference inputs:
        Z      (N, 1, C, H, W)  noisy image latent
        Motion (N, k, d, h, w)  motion latent
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        num_attention_heads: int = 30,
        attention_head_dim: int = 64,
        image_in_channels: Optional[int] = 4,
        motion_in_channels: Optional[int] = 16,
        out_channels: Optional[int] = 4,
        num_layers: int = 16,

        image_width: int = 32,
        image_height: int = 32,
        motion_width: int = 8,
        motion_height: int = 8,
        image_patch_size: int = 2,
        motion_patch_size: int = 1,
        motion_frames: int = 15,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        iph = image_height // image_patch_size   # image patches along height
        ipw = image_width // image_patch_size    # image patches along width
        itl = iph * ipw                          # image tokens per "frame"

        mph = motion_height // motion_patch_size
        mpw = motion_width // motion_patch_size
        mtl = mph * mpw * motion_frames          # total motion tokens

        # noisy-image tokens + reference-image tokens + motion tokens
        self.max_seq_length = 2 * itl + mtl

        self.image_patch_size = image_patch_size
        self.motion_patch_size = motion_patch_size
        self.out_channels = out_channels

        # Separate patch embeddings for the noisy latent, the reference-image
        # latent and the motion latent.
        self.image_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True)
        self.refimg_patch_embed = PatchEmbed(patch_size=image_patch_size, in_channels=image_in_channels, embed_dim=hidden_dim, bias=True)
        self.motion_patch_embed = PatchEmbed(patch_size=motion_patch_size, in_channels=motion_in_channels, embed_dim=hidden_dim, bias=True)

        self.embedding_dropout = nn.Dropout(dropout)

        # 3D sin-cos positional embedding covering 2 "frames": one for the
        # noisy latent, one for the reference image.
        # BUGFIX: pass (iph, ipw) — the original passed (iph, iph), which is
        # wrong for non-square images (identical for the square defaults).
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            hidden_dim,
            (iph, ipw),
            2,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        pos_embedding = torch.zeros(1, *spatial_pos_embedding.shape, requires_grad=False)
        pos_embedding.data.copy_(spatial_pos_embedding)
        self.register_buffer("img_pos_embedding", pos_embedding, persistent=False)

        # Positional embedding for the motion tokens over `motion_frames` frames.
        # BUGFIX: (mph, mpw) instead of (mph, mph), same reasoning as above.
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            hidden_dim,
            (mph, mpw),
            motion_frames,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        pos_embedding = torch.zeros(1, *spatial_pos_embedding.shape, requires_grad=False)
        pos_embedding.data.copy_(spatial_pos_embedding)
        self.register_buffer("motion_pos_embedding", pos_embedding, persistent=False)

        # Timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Main transformer stack over (image, ref-image, motion) streams.
        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock2Condition(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )
        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Output head: AdaLN modulation followed by a linear un-patchify projection.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=2 * hidden_dim,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, image_patch_size * image_patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        refimg_hidden_states: torch.Tensor,
        motion_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Args:
            hidden_states: (b, c, H, W) noisy image latent.
            refimg_hidden_states: (b, c, H, W) reference image latent.
            motion_hidden_states: (b, d, h, w) motion latent.
            timestep: (b,) diffusion timestep.
            timestep_cond: optional extra conditioning for the time embedding.
            return_dict: unused; kept for interface compatibility.

        Returns:
            (b, out_channels, H, W) prediction, un-patchified.
        """
        N, Ci, Hi, Wi = hidden_states.shape
        N, Cm, Hm, Wm = motion_hidden_states.shape
        image_seq_length = Hi * Wi // (self.image_patch_size**2)
        motion_seq_length = Hm * Wm // (self.motion_patch_size**2)

        # Timestep embedding; time_proj may run in fp32, so cast to model dtype.
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Patchify each stream to token sequences.
        hidden_states = self.image_patch_embed(hidden_states)
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        refimg_hidden_states = self.refimg_patch_embed(refimg_hidden_states)

        # Add positional embeddings.
        # NOTE(review): the ref-image slice assumes image_seq_length equals the
        # configured per-frame token count (itl) — confirm for non-default sizes.
        hidden_states = hidden_states + self.img_pos_embedding[:, :image_seq_length, :]
        refimg_hidden_states = refimg_hidden_states + self.img_pos_embedding[:, image_seq_length:, :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_seq_length, :]
        # BUGFIX: the original discarded the dropout result, making it a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                # BUGFIX: the original unpacked into
                # (motion_hidden_states, refimg_hidden_states, motion_hidden_states),
                # dropping the updated hidden_states and assigning motion twice.
                hidden_states, refimg_hidden_states, motion_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    refimg_hidden_states,
                    motion_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, refimg_hidden_states, motion_hidden_states = block(
                    hidden_states,
                    refimg_hidden_states,
                    motion_hidden_states,
                    emb,
                )

        hidden_states = self.norm_final(hidden_states)

        # AdaLN-modulated output head.
        hidden_states = self.norm_out(hidden_states, temb=emb)
        hidden_states = self.proj_out(hidden_states)

        # Un-patchify: (N, L, p*p*C) -> (N, out_channels, Hi, Wi).
        p = self.image_patch_size
        output = hidden_states.reshape(N, 1, Hi // p, Wi // p, self.out_channels, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4).squeeze(1)

        return output
|
|
|
|
|
|
| |
|
|
| |
class AudioMitionref_LearnableToken(nn.Module):
    """
    Audio-to-motion transformer over learnable motion tokens.

    Denoises per-frame motion tokens conditioned on reference motion tokens
    and an extra (audio) feature stream, using ``TransformerBlock2Condition``
    blocks with full AdaLN timestep conditioning.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        extra_in_channels: Optional[int] = 768,
        out_channels: Optional[int] = 128,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        self.motion_num_token = motion_num_token
        # Total motion tokens across all frames.
        self.motion_num_tokens = motion_num_token * motion_frames

        # Linear token embeddings (motion tokens are already sequences, not images).
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.extra_embed = nn.Linear(extra_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # 1D sin-cos positional embedding shared by ref tokens (first
        # motion_num_token positions) and per-frame motion tokens (rest).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_token + self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # 1D positional embedding for the per-frame extra (audio) features.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_frames))
        audio_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        audio_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("audio_pos_embedding", audio_pos_embedding, persistent=False)

        # Timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock2Condition(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN-modulated output head.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        extra_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            motion_hidden_states: (N, F, L, D) noisy motion tokens.
            refmotion_hidden_states: (N, L, D) reference motion tokens.
            extra_hidden_states: (N, F, D) per-frame extra (audio) features.
                If additional positional encoding is needed it should be applied
                by the caller; the first two inputs are always positionally
                encoded here.
            timestep: (N,) diffusion timestep.

        Returns:
            (N, F, L, out_channels) denoised motion tokens.
        """
        assert motion_hidden_states.shape[1] == extra_hidden_states.shape[1],f"motion {motion_hidden_states.shape} ,audio{extra_hidden_states.shape}"

        N, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding; cast to model dtype (time_proj may run in fp32).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten frames into one token sequence and embed each stream.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = self.refmotion_patch_embed(refmotion_hidden_states)
        extra_hidden_states = self.extra_embed(extra_hidden_states)

        # Positional embeddings: ref tokens take the first L positions, motion
        # tokens the following F*L positions.
        # BUGFIX: slice to the actual sequence lengths — the original open-ended
        # slices only worked when F == motion_frames.
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :L, :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, L:L + motion_hidden_states.shape[1], :]
        extra_hidden_states = extra_hidden_states + self.audio_pos_embedding[:, :extra_hidden_states.shape[1], :]
        # BUGFIX: the original discarded the dropout result, making it a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                motion_hidden_states, ref_motion_hidden_states, extra_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    motion_hidden_states,
                    ref_motion_hidden_states,
                    extra_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                motion_hidden_states, ref_motion_hidden_states, extra_hidden_states = block(
                    motion_hidden_states,
                    ref_motion_hidden_states,
                    extra_hidden_states,
                    temb=emb,
                )

        motion_hidden_states = self.norm_final(motion_hidden_states)

        # AdaLN-modulated output head.
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the per-frame layout.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
|
|
| |
class AudioMitionref_LearnableToken_SimpleAdaLN(nn.Module):
    """
    Audio-to-motion transformer over learnable motion tokens, using the
    lighter ``TransformerBlock2Condition_SimpleAdaLN`` blocks.

    Same interface and data flow as ``AudioMitionref_LearnableToken``; only
    the transformer block type differs.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        extra_in_channels: Optional[int] = 768,
        out_channels: Optional[int] = 128,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        self.motion_num_token = motion_num_token
        # Total motion tokens across all frames.
        self.motion_num_tokens = motion_num_token * motion_frames

        # Linear token embeddings.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.extra_embed = nn.Linear(extra_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # 1D sin-cos positional embedding shared by ref tokens (first
        # motion_num_token positions) and per-frame motion tokens (rest).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_token + self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # 1D positional embedding for the per-frame extra (audio) features.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_frames))
        audio_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        audio_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("audio_pos_embedding", audio_pos_embedding, persistent=False)

        # Timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        self.transformer_blocks = nn.ModuleList(
            [
                TransformerBlock2Condition_SimpleAdaLN(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN-modulated output head.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        extra_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ):
        """
        Args:
            motion_hidden_states: (N, F, L, D) noisy motion tokens.
            refmotion_hidden_states: (N, L, D) reference motion tokens.
            extra_hidden_states: (N, F, D) per-frame extra (audio) features.
                If additional positional encoding is needed it should be applied
                by the caller; the first two inputs are always positionally
                encoded here.
            timestep: (N,) diffusion timestep.

        Returns:
            (N, F, L, out_channels) denoised motion tokens.
        """
        assert motion_hidden_states.shape[1] == extra_hidden_states.shape[1],f"motion {motion_hidden_states.shape} ,audio{extra_hidden_states.shape}"

        N, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding; cast to model dtype (time_proj may run in fp32).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten frames into one token sequence and embed each stream.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = self.refmotion_patch_embed(refmotion_hidden_states)
        extra_hidden_states = self.extra_embed(extra_hidden_states)

        # Positional embeddings: ref tokens take the first L positions, motion
        # tokens the following F*L positions.
        # BUGFIX: slice to the actual sequence lengths — the original open-ended
        # slices only worked when F == motion_frames.
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :L, :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, L:L + motion_hidden_states.shape[1], :]
        extra_hidden_states = extra_hidden_states + self.audio_pos_embedding[:, :extra_hidden_states.shape[1], :]
        # BUGFIX: the original discarded the dropout result, making it a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        for i, block in enumerate(self.transformer_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                motion_hidden_states, ref_motion_hidden_states, extra_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    motion_hidden_states,
                    ref_motion_hidden_states,
                    extra_hidden_states,
                    emb,
                    **ckpt_kwargs,
                )
            else:
                motion_hidden_states, ref_motion_hidden_states, extra_hidden_states = block(
                    motion_hidden_states,
                    ref_motion_hidden_states,
                    extra_hidden_states,
                    temb=emb,
                )

        motion_hidden_states = self.norm_final(motion_hidden_states)

        # AdaLN-modulated output head.
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the per-frame layout.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
|
|
| |
class A2MTransformer_CrossAttn_Audio(nn.Module):
    """
    Audio-to-motion transformer that interleaves motion self-attention blocks
    with audio cross-attention blocks.

    Each layer runs an ``A2MMotionSelfAttnBlock`` over (motion, ref-motion)
    followed by an ``A2MCrossAttnBlock`` attending to the audio features.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        audio_window : Optional[int] = 12,
        audio_in_channels: Optional[int] = 128,
        out_channels: Optional[int] = 128,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        self.motion_num_token = motion_num_token
        # Total motion tokens across all frames.
        self.motion_num_tokens = motion_num_token * motion_frames

        # Linear token embeddings for motion / reference-motion / audio streams.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.audio_embed = nn.Linear(audio_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # 1D sin-cos positional embedding shared by ref and noisy motion tokens
        # (both are sliced from position 0 in forward).
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        self.motion_blocks = nn.ModuleList(
            [
                A2MMotionSelfAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.audio_blocks = nn.ModuleList(
            [
                A2MCrossAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN-modulated output head.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        # NOTE(review): this flag is settable but forward() never wraps the
        # blocks in torch.utils.checkpoint — confirm whether checkpointing was
        # intended here as in the sibling classes.
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Args:
            motion_hidden_states: (N, F, L, D) noisy motion tokens.
            refmotion_hidden_states: (N, T, L, D) reference motion tokens.
            audio_hidden_states: (N, T+F, W, D) audio features.
            timestep: (N,) diffusion timestep.

        Returns:
            (N, F, L, out_channels) denoised motion tokens.
        """
        N, T, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding; cast to model dtype (time_proj may run in fp32).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten frames into token sequences and embed each stream.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states, 'n t l d -> n (t l) d')
        ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states)
        audio_hidden_states = self.audio_embed(audio_hidden_states)

        # Both motion streams share the same positional table from position 0.
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :ref_motion_hidden_states.shape[1], :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_hidden_states.shape[1], :]
        # BUGFIX: the original discarded the dropout result, making it a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Alternate motion self-attention and audio cross-attention layers.
        for motion_block, audio_block in zip(self.motion_blocks, self.audio_blocks):
            motion_hidden_states, ref_motion_hidden_states = motion_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                temb=emb,
            )

            motion_hidden_states, ref_motion_hidden_states = audio_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                audio_hidden_states,
                temb=emb,
            )

        motion_hidden_states = self.norm_final(motion_hidden_states)

        # AdaLN-modulated output head.
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the per-frame layout.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
class A2MTransformer_CrossAttn_Audio_DoubleRef(nn.Module):
    """
    Audio-to-motion transformer with two reference-motion streams.

    Like ``A2MTransformer_CrossAttn_Audio``, but the motion self-attention
    blocks (``A2MMotionSelfAttnBlockDoubleRef``) additionally consume a
    second, randomly sampled reference-motion stream.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        audio_window : Optional[int] = 12,
        audio_in_channels: Optional[int] = 128,
        out_channels: Optional[int] = 128,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        self.motion_num_token = motion_num_token
        # Total motion tokens across all frames.
        self.motion_num_tokens = motion_num_token * motion_frames

        # Linear token embeddings; the random-ref stream reuses
        # refmotion_patch_embed in forward().
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.audio_embed = nn.Linear(audio_in_channels, hidden_dim)
        self.embedding_dropout = nn.Dropout(dropout)

        # 1D sin-cos positional embedding shared by ref and noisy motion tokens.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # Timestep conditioning.
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # The last layer is flagged so the block can skip updating streams
        # whose outputs are unused.
        self.motion_blocks = nn.ModuleList(
            [
                A2MMotionSelfAttnBlockDoubleRef(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    last_layer = (i==num_layers-1),
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for i in range(num_layers)
            ]
        )

        self.audio_blocks = nn.ModuleList(
            [
                A2MCrossAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # AdaLN-modulated output head.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        # NOTE(review): this flag is settable but forward() never wraps the
        # blocks in torch.utils.checkpoint — confirm whether checkpointing was
        # intended here as in the sibling classes.
        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        randomrefmotion_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Args:
            motion_hidden_states: (N, F, L, D) noisy motion tokens.
            refmotion_hidden_states: (N, T, L, D) reference motion tokens.
            randomrefmotion_hidden_states: (N, S, L, D) randomly sampled
                reference motion tokens (no positional embedding added).
            audio_hidden_states: (N, T+F, W, D) audio features.
            timestep: (N,) diffusion timestep.

        Returns:
            (N, F, L, out_channels) denoised motion tokens.
        """
        N, T, L, D = refmotion_hidden_states.shape
        N, F, L, D = motion_hidden_states.shape

        # Timestep embedding; cast to model dtype (time_proj may run in fp32).
        t_emb = self.time_proj(timestep)
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Flatten frames into token sequences and embed each stream.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states, 'n t l d -> n (t l) d')
        ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states)
        randomref_motion_hidden_states = einops.rearrange(randomrefmotion_hidden_states, 'n s l d -> n (s l) d')
        randomref_motion_hidden_states = self.refmotion_patch_embed(randomref_motion_hidden_states)
        audio_hidden_states = self.audio_embed(audio_hidden_states)

        # Positional embeddings on the two main motion streams.
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :ref_motion_hidden_states.shape[1], :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_hidden_states.shape[1], :]
        # NOTE(review): expands the random-ref batch by F — presumably to pair
        # one random reference per frame inside the block; confirm against
        # A2MMotionSelfAttnBlockDoubleRef.
        randomref_motion_hidden_states = randomref_motion_hidden_states.repeat_interleave(F, dim=0)
        # BUGFIX: the original discarded the dropout result, making it a no-op.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)

        # Alternate double-ref motion self-attention and audio cross-attention.
        for motion_block, audio_block in zip(self.motion_blocks, self.audio_blocks):
            motion_hidden_states, ref_motion_hidden_states, randomref_motion_hidden_states = motion_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                randomref_motion_hidden_states,
                temb=emb,
                motion_token = L,
            )

            motion_hidden_states, ref_motion_hidden_states = audio_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                audio_hidden_states,
                temb=emb,
            )

        motion_hidden_states = self.norm_final(motion_hidden_states)

        # AdaLN-modulated output head.
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the per-frame layout.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)

        return output
|
|
|
|
| |
| class A2MTransformer_CrossAttn_Audio_Pose(nn.Module): |
| _supports_gradient_checkpointing = True |
| def __init__( |
| self, |
| motion_num_token: int = 12, |
| motion_inchannel: int = 128, |
| motion_frames: int = 128, |
| |
| audio_window : Optional[int] = 12, |
| audio_in_channels: Optional[int] = 128, |
| out_channels: Optional[int] = 128, |
| |
| pose_height : int = 32, |
| pose_width : int = 32, |
| pose_inchannel : int = 4, |
| pose_patch_size : int = 2, |
| |
| num_attention_heads: int = 8, |
| attention_head_dim: int = 64, |
| num_layers: int = 16, |
| |
| flip_sin_to_cos: bool = True, |
| freq_shift: int = 0, |
| time_embed_dim: int = 512, |
| |
| dropout: float = 0.0, |
| attention_bias: bool = True, |
| temporal_compression_ratio: int = 4, |
| |
| activation_fn: str = "gelu-approximate", |
| timestep_activation_fn: str = "silu", |
| norm_elementwise_affine: bool = True, |
| norm_eps: float = 1e-5, |
| spatial_interpolation_scale: float = 1.0, |
| temporal_interpolation_scale: float = 1.0, |
| ): |
| super().__init__() |
| |
| |
| hidden_dim = num_attention_heads * attention_head_dim |
| self.out_channels = out_channels |
| self.motion_frames = motion_frames |
| self.motion_num_token = motion_num_token |
| self.motion_num_tokens = motion_num_token * motion_frames |
|
|
| |
| self.refmotion_patch_embed = nn.Linear(motion_inchannel,hidden_dim) |
| self.motion_patch_embed = nn.Linear(motion_inchannel,hidden_dim) |
| self.audio_embed = nn.Linear(audio_in_channels,hidden_dim) |
| self.pose_embed = PatchEmbed(patch_size=pose_patch_size,in_channels=pose_inchannel,embed_dim=hidden_dim, bias=True) |
| self.embedding_dropout = nn.Dropout(dropout) |
|
|
| |
| temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim,torch.arange( self.motion_num_tokens)) |
| motion_pos_embedding = torch.zeros(1,*temporal_embedding.shape,requires_grad=False) |
| motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding)) |
| self.register_buffer("motion_pos_embedding",motion_pos_embedding,persistent=False) |
|
|
| |
| iph = pose_height // pose_patch_size |
| ipw = pose_width // pose_patch_size |
| itl = iph * ipw |
| image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw)) |
| image_pos_embedding = torch.from_numpy(image_pos_embedding) |
| pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False) |
| pos_embedding.data[:, :itl].copy_(image_pos_embedding) |
| self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False) |
| |
| |
| self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift) |
| self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn) |
|
|
| |
| self.motion_blocks = nn.ModuleList( |
| [ |
| A2MMotionSelfAttnBlock( |
| dim=hidden_dim, |
| num_attention_heads=num_attention_heads, |
| attention_head_dim=attention_head_dim, |
| time_embed_dim=time_embed_dim, |
| dropout=dropout, |
| activation_fn=activation_fn, |
| attention_bias=attention_bias, |
| norm_elementwise_affine=norm_elementwise_affine, |
| norm_eps=norm_eps, |
| ) |
| for _ in range(num_layers) |
| ] |
| ) |
|
|
| self.audio_blocks = nn.ModuleList( |
| [ |
| A2MCrossAttnBlock( |
| dim=hidden_dim, |
| num_attention_heads=num_attention_heads, |
| attention_head_dim=attention_head_dim, |
| time_embed_dim=time_embed_dim, |
| dropout=dropout, |
| activation_fn=activation_fn, |
| attention_bias=attention_bias, |
| norm_elementwise_affine=norm_elementwise_affine, |
| norm_eps=norm_eps, |
| ) |
| for _ in range(num_layers) |
| ] |
| ) |
|
|
| self.pose_blocks = nn.ModuleList( |
| [ |
| A2MCrossAttnBlock( |
| dim=hidden_dim, |
| num_attention_heads=num_attention_heads, |
| attention_head_dim=attention_head_dim, |
| time_embed_dim=time_embed_dim, |
| dropout=dropout, |
| activation_fn=activation_fn, |
| attention_bias=attention_bias, |
| norm_elementwise_affine=norm_elementwise_affine, |
| norm_eps=norm_eps, |
| ) |
| for _ in range(num_layers) |
| ] |
| ) |
|
|
| self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine) |
|
|
| |
| self.norm_out = AdaLayerNorm( |
| embedding_dim=time_embed_dim, |
| output_dim=hidden_dim*2, |
| norm_elementwise_affine=norm_elementwise_affine, |
| norm_eps=norm_eps, |
| chunk_dim=1, |
| ) |
| self.proj_out = nn.Linear(hidden_dim, out_channels) |
|
|
| self.gradient_checkpointing = False |
|
|
    def _set_gradient_checkpointing(self, module, value=False):
        # Hook called by diffusers' ModelMixin when gradient checkpointing is
        # enabled/disabled. `module` is ignored; the flag is stored on self.
        # NOTE(review): the flag is not read anywhere in the visible code --
        # presumably consumed by the transformer blocks' forward; confirm.
        self.gradient_checkpointing = value
|
|
| def forward( |
| self, |
| motion_hidden_states: torch.Tensor, |
| refmotion_hidden_states: torch.Tensor, |
| audio_hidden_states: torch.Tensor, |
| pose_hidden_states: torch.Tensor, |
| timestep: Union[int, float, torch.LongTensor], |
| timestep_cond: Optional[torch.Tensor] = None, |
| return_dict: bool = True, |
| **kwargs, |
| ): |
| """ |
| motion_hidden_states : (N,F,L,D) |
| refmotion_hidden_states : (N,L,D) |
| audio_hidden_states : (N,F+1,W,D) |
| pose_hidden_states : (N,F+1,c,h,w) |
| """ |
|
|
| N,T,L,D = refmotion_hidden_states.shape |
| N,F,L,D = motion_hidden_states.shape |
| |
| |
| t_emb = self.time_proj(timestep) |
| t_emb = t_emb.to(dtype=motion_hidden_states.dtype) |
| emb = self.time_embedding(t_emb, timestep_cond) |
|
|
| |
| motion_hidden_states = einops.rearrange(motion_hidden_states,'n f l d -> n (f l) d') |
| motion_hidden_states = self.motion_patch_embed(motion_hidden_states) |
| ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states,'n f l d -> n (f l) d') |
| ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states) |
|
|
| audio_hidden_states = self.audio_embed(audio_hidden_states) |
| pose_hidden_states = self.pose_embed(pose_hidden_states.flatten(0,1)) |
|
|
| |
| ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:,:ref_motion_hidden_states.shape[1],:] |
|
|
| motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:,:motion_hidden_states.shape[1],:] |
|
|
| pose_hidden_states = pose_hidden_states + self.pose_pos_embedding |
| pose_hidden_states = einops.rearrange(pose_hidden_states,'(n f) l d -> n f l d',n=N) |
| self.embedding_dropout(motion_hidden_states) |
|
|
| |
| for motion_block,audio_block,pose_block in zip(self.motion_blocks,self.audio_blocks,self.pose_blocks): |
|
|
| motion_hidden_states,ref_motion_hidden_states = motion_block( |
| motion_hidden_states, |
| ref_motion_hidden_states, |
| temb=emb, |
| ) |
|
|
| motion_hidden_states,ref_motion_hidden_states = audio_block( |
| motion_hidden_states, |
| ref_motion_hidden_states, |
| audio_hidden_states, |
| temb=emb, |
| ) |
|
|
| motion_hidden_states,ref_motion_hidden_states = pose_block( |
| motion_hidden_states, |
| ref_motion_hidden_states, |
| pose_hidden_states, |
| temb=emb, |
| ) |
|
|
| motion_hidden_states = self.norm_final(motion_hidden_states) |
|
|
| |
| motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb) |
| motion_hidden_states = self.proj_out(motion_hidden_states) |
|
|
| |
| output = einops.rearrange(motion_hidden_states,'n (f l) d -> n f l d',f=F) |
| |
| return output |
|
|
|
|
| |
class A2MTransformer_CrossAttn_Pose(nn.Module):
    """
    Audio-free A2M transformer: denoises per-frame motion tokens conditioned
    on reference motion tokens (joint self-attention) and pose latents
    (cross-attention), with AdaLN timestep modulation throughout.

    Motion inputs/outputs are token sequences of shape
    (batch, frames, tokens, channels); see `forward` for details.
    """

    _supports_gradient_checkpointing = True

    def __init__(
        self,
        motion_num_token: int = 12,
        motion_inchannel: int = 128,
        motion_frames: int = 128,

        out_channels: Optional[int] = 128,

        pose_height : int = 32,
        pose_width : int = 32,
        pose_inchannel : int = 4,
        pose_patch_size : int = 2,

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,  # NOTE(review): unused here; kept for interface compatibility

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,      # NOTE(review): unused here
        temporal_interpolation_scale: float = 1.0,     # NOTE(review): unused here
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channels = out_channels
        self.motion_frames = motion_frames
        self.motion_num_token = motion_num_token
        # Total number of rows in the temporal position table below.
        self.motion_num_tokens = motion_num_token * motion_frames

        # Input projections: motion streams use plain linear embeds; pose
        # latents are patchified images.
        self.refmotion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.motion_patch_embed = nn.Linear(motion_inchannel, hidden_dim)
        self.pose_embed = PatchEmbed(patch_size=pose_patch_size, in_channels=pose_inchannel, embed_dim=hidden_dim, bias=True)
        self.embedding_dropout = nn.Dropout(dropout)

        # 1D sinusoidal position table shared by motion and reference motion.
        # NOTE(review): both streams index the first seq_len rows, so this
        # assumes T*L <= motion_num_token * motion_frames -- confirm.
        temporal_embedding = get_1d_sincos_pos_embed_from_grid(hidden_dim, torch.arange(self.motion_num_tokens))
        motion_pos_embedding = torch.zeros(1, *temporal_embedding.shape, requires_grad=False)
        motion_pos_embedding.data.copy_(torch.from_numpy(temporal_embedding))
        self.register_buffer("motion_pos_embedding", motion_pos_embedding, persistent=False)

        # 2D sinusoidal position table for one patchified pose frame.
        iph = pose_height // pose_patch_size
        ipw = pose_width // pose_patch_size
        itl = iph * ipw
        image_pos_embedding = get_2d_sincos_pos_embed(hidden_dim, (iph, ipw))
        image_pos_embedding = torch.from_numpy(image_pos_embedding)
        pos_embedding = torch.zeros(1, itl, hidden_dim, requires_grad=False)
        pos_embedding.data[:, :itl].copy_(image_pos_embedding)
        self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False)

        # Diffusion-timestep embedding (sinusoidal projection + MLP).
        self.time_proj = Timesteps(hidden_dim, flip_sin_to_cos, freq_shift)
        self.time_embedding = TimestepEmbedding(hidden_dim, time_embed_dim, timestep_activation_fn)

        # Per-layer pairs of (motion self-attn, pose cross-attn) blocks,
        # zipped together in forward().
        self.motion_blocks = nn.ModuleList(
            [
                A2MMotionSelfAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.pose_blocks = nn.ModuleList(
            [
                A2MCrossAttnBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    time_embed_dim=time_embed_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Output head: AdaLN modulation then linear projection.
        self.norm_out = AdaLayerNorm(
            embedding_dim=time_embed_dim,
            output_dim=hidden_dim * 2,
            norm_elementwise_affine=norm_elementwise_affine,
            norm_eps=norm_eps,
            chunk_dim=1,
        )
        self.proj_out = nn.Linear(hidden_dim, out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook called by diffusers when toggling gradient checkpointing;
        # `module` is ignored and the flag is stored on self.
        self.gradient_checkpointing = value

    def forward(
        self,
        motion_hidden_states: torch.Tensor,
        refmotion_hidden_states: torch.Tensor,
        pose_hidden_states: torch.Tensor,
        timestep: Union[int, float, torch.LongTensor],
        timestep_cond: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        **kwargs,
    ):
        """
        Denoise motion tokens conditioned on reference motion and pose.

        Args:
            motion_hidden_states: (N, F, L, D) noisy motion tokens.
            refmotion_hidden_states: (N, T, L, D) reference motion tokens.
            pose_hidden_states: (N, T+F, c, h, w) pose latents
                -- assumed layout, TODO confirm against caller.
            timestep: diffusion timestep(s).
            timestep_cond: optional extra conditioning for the timestep MLP.
            return_dict: unused; kept for interface compatibility.

        Returns:
            (N, F, L, out_channels) predicted motion tokens.
        """
        N, T, L, D = refmotion_hidden_states.shape
        _, F, _, _ = motion_hidden_states.shape

        # Timestep embedding, shared by every block.
        t_emb = self.time_proj(timestep)
        # time_proj may compute in fp32; cast back to the activation dtype.
        t_emb = t_emb.to(dtype=motion_hidden_states.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)

        # Project inputs; motion streams are flattened to one sequence each.
        motion_hidden_states = einops.rearrange(motion_hidden_states, 'n f l d -> n (f l) d')
        motion_hidden_states = self.motion_patch_embed(motion_hidden_states)
        ref_motion_hidden_states = einops.rearrange(refmotion_hidden_states, 'n t l d -> n (t l) d')
        ref_motion_hidden_states = self.refmotion_patch_embed(ref_motion_hidden_states)
        pose_hidden_states = self.pose_embed(pose_hidden_states.flatten(0, 1))

        # Additive sinusoidal position embeddings.
        ref_motion_hidden_states = ref_motion_hidden_states + self.motion_pos_embedding[:, :ref_motion_hidden_states.shape[1], :]
        motion_hidden_states = motion_hidden_states + self.motion_pos_embedding[:, :motion_hidden_states.shape[1], :]
        # FIX: the dropout result was previously discarded
        # (`self.embedding_dropout(motion_hidden_states)` on its own line),
        # which made embedding dropout a silent no-op during training.
        motion_hidden_states = self.embedding_dropout(motion_hidden_states)
        pose_hidden_states = pose_hidden_states + self.pose_pos_embedding
        pose_hidden_states = einops.rearrange(pose_hidden_states, '(n f) l d -> n f l d', n=N)

        # Per layer: motion self-attn, then pose cross-attn; each block also
        # updates the reference-motion stream.
        for motion_block, pose_block in zip(self.motion_blocks, self.pose_blocks):

            motion_hidden_states, ref_motion_hidden_states = motion_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                temb=emb,
            )

            motion_hidden_states, ref_motion_hidden_states = pose_block(
                motion_hidden_states,
                ref_motion_hidden_states,
                pose_hidden_states,
                temb=emb,
            )

        motion_hidden_states = self.norm_final(motion_hidden_states)

        # Output head: AdaLN modulation then projection to out_channels.
        motion_hidden_states = self.norm_out(motion_hidden_states, temb=emb)
        motion_hidden_states = self.proj_out(motion_hidden_states)

        # Restore the (frame, token) factorization.
        output = einops.rearrange(motion_hidden_states, 'n (f l) d -> n f l d', f=F)
        return output
|
|
|
|
| |
|
|
class A2PTransformer(nn.Module):
    """
    Audio-to-pose transformer.

    Given T reference pose latent frames and audio features covering T+F
    frames, the F unknown frames are filled with a learnable mask token and
    the model predicts pose latents for all T+F frames via interleaved
    temporal/spatial self-attention and audio cross-attention, followed by
    an unpatchify projection back to latent images.
    """
    _supports_gradient_checkpointing = True
    def __init__(
        self,

        audio_window : Optional[int] = 12,          # NOTE(review): unused in this class -- confirm
        audio_in_channels: Optional[int] = 128,     # channel dim of incoming audio features

        pose_height : int = 32,                     # pose latent height (pixels/latents)
        pose_width : int = 32,
        pose_inchannel : int = 4,                   # pose latent channels; also the output channel count
        pose_patch_size : int = 4,
        pose_frame : int = 17,                      # max frames covered by the 3D position table

        num_attention_heads: int = 8,
        attention_head_dim: int = 64,
        num_layers: int = 16,

        flip_sin_to_cos: bool = True,
        freq_shift: int = 0,
        time_embed_dim: int = 512,                  # NOTE(review): unused -- no timestep embedding in this class

        dropout: float = 0.0,
        attention_bias: bool = True,
        temporal_compression_ratio: int = 4,        # NOTE(review): unused in this class

        activation_fn: str = "gelu-approximate",
        timestep_activation_fn: str = "silu",       # NOTE(review): unused in this class
        norm_elementwise_affine: bool = True,
        norm_eps: float = 1e-5,
        spatial_interpolation_scale: float = 1.0,
        temporal_interpolation_scale: float = 1.0,
    ):
        super().__init__()

        hidden_dim = num_attention_heads * attention_head_dim
        self.out_channel = pose_inchannel
        self.pose_patch_size = pose_patch_size

        # Input projections: pose latents are patchified; audio is a linear embed.
        self.pose_embed = PatchEmbed(patch_size=pose_patch_size,in_channels=pose_inchannel,embed_dim=hidden_dim, bias=True)
        self.audio_embed = nn.Linear(audio_in_channels,hidden_dim)

        # Learnable mask token, one per spatial patch position, broadcast over
        # the F frames to be predicted.
        iph = pose_height // pose_patch_size
        ipw = pose_width // pose_patch_size
        itl = iph * ipw
        INIT_CONST = 0.02
        self.pose_mask_token = nn.Parameter(torch.randn(1, itl, hidden_dim) * INIT_CONST)

        # 3D (temporal + spatial) sinusoidal position table, flattened to
        # (pose_frame * itl, hidden_dim) and stored as a non-persistent buffer.
        spatial_pos_embedding = get_3d_sincos_pos_embed(
            hidden_dim,
            (iph, ipw),
            pose_frame,
            spatial_interpolation_scale,
            temporal_interpolation_scale,
        )
        spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
        pos_embedding = torch.zeros(1,*spatial_pos_embedding.shape, requires_grad=False)
        pos_embedding.data.copy_(spatial_pos_embedding)
        self.register_buffer("pose_pos_embedding", pos_embedding, persistent=False)

        # Per-layer pairs of (temporal/spatial self-attn, audio cross-attn)
        # blocks, zipped together in forward().
        self.temporal_spatial_blocks = nn.ModuleList(
            [
                A2PTemporalSpatialBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.audio_blocks = nn.ModuleList(
            [
                A2PCrossAudioBlock(
                    dim=hidden_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    dropout=dropout,
                    activation_fn=activation_fn,
                    attention_bias=attention_bias,
                    norm_elementwise_affine=norm_elementwise_affine,
                    norm_eps=norm_eps,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_final = nn.LayerNorm(hidden_dim, norm_eps, norm_elementwise_affine)

        # Output head: project each token to a flattened patch of latents.
        self.proj_out = nn.Linear(hidden_dim, pose_patch_size * pose_patch_size * pose_inchannel)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Hook called by diffusers when toggling gradient checkpointing;
        # `module` is ignored and the flag is stored on self.
        self.gradient_checkpointing = value

    def forward(
        self,
        ref_pose_hidden_states: torch.Tensor,
        audio_hidden_states: torch.Tensor,
        **kwargs,
    ):
        """
        Predict pose latents for T+F frames from T reference frames and audio.

        Args:
            ref_pose_hidden_states: (N, T, C, H, W) reference pose latents.
            audio_hidden_states: (N, T+F, W, D) audio features; F (the number
                of frames to predict) is inferred as T_F - T.

        Returns:
            (N, T+F, out_channel, H, W) predicted pose latents.
        """

        N,T,C,H,W = ref_pose_hidden_states.shape
        N,T_F,L,D = audio_hidden_states.shape
        F = T_F - T  # number of frames to synthesize

        # Embed audio per frame-window.
        audio_hidden_states = einops.rearrange(audio_hidden_states,'n f w d -> (n f) w d')
        audio_hidden_states = self.audio_embed(audio_hidden_states)
        audio_hidden_states = einops.rearrange(audio_hidden_states,'(n f) w d -> n f w d',n=N)

        # Patchify reference pose frames; tile the mask token over the F
        # unknown frames.
        ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'n t c h w -> (n t) c h w')
        ref_pose_hidden_states = self.pose_embed(ref_pose_hidden_states)
        ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'(n t) l d -> n t l d',n=N)
        pose_mask_hidden_states = self.pose_mask_token.unsqueeze(0).repeat(N,F,1,1)

        # Add 3D position embeddings to the reference frames (rows 0..T*L).
        ref_pose_hidden_states = ref_pose_hidden_states.flatten(1,2)
        ref_pose_hidden_states = ref_pose_hidden_states + self.pose_pos_embedding[:,:ref_pose_hidden_states.shape[1],:]
        ref_pose_hidden_states = einops.rearrange(ref_pose_hidden_states,'n (t l) d -> n t l d',t=T)

        # NOTE(review): mask tokens also take rows 0..F*L of the table, i.e.
        # temporal positions restart at frame 0 rather than continuing at
        # frame T -- confirm this overlap with the reference frames is
        # intentional and not a missing T*L offset.
        pose_mask_hidden_states = pose_mask_hidden_states.flatten(1,2)
        pose_mask_hidden_states = pose_mask_hidden_states + self.pose_pos_embedding[:,:pose_mask_hidden_states.shape[1],:]
        pose_mask_hidden_states = einops.rearrange(pose_mask_hidden_states,'n (f l) d -> n f l d',f=F)
        pose_hidden_states = torch.cat((ref_pose_hidden_states,pose_mask_hidden_states),dim=1)

        # Interleaved temporal/spatial self-attention and audio cross-attention.
        for st_block,audio_block in zip(self.temporal_spatial_blocks,self.audio_blocks):

            pose_hidden_states = st_block(
                pose_hidden_states,
            )

            pose_hidden_states = audio_block(
                pose_hidden_states,
                audio_hidden_states,
            )

        pose_hidden_states = self.norm_final(pose_hidden_states)

        # Project each token to a flattened (p x p x C) patch.
        pose_hidden_states = self.proj_out(pose_hidden_states)

        # Unpatchify: (N, T_F, H/p, W/p, C, p, p) -> (N, T_F, C, H, W).
        p = self.pose_patch_size
        output = pose_hidden_states.reshape(N, T_F, H // p, W // p, self.out_channel, p, p)
        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

        return output
|
|