import torch
from torch import nn

from .transformer import (
    PositionalEncoding,
    TransformerDecoderLayer,
    TransformerDecoder,
)
from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder
from core.utils import _reset_parameters


def get_decoder_network(
    network_type,
    d_model,
    nhead,
    dim_feedforward,
    dropout,
    activation,
    normalize_before,
    num_decoder_layers,
    return_intermediate_dec,
    dynamic_K,
    dynamic_ratio,
):
    """Build a decoder stack of the requested kind.

    Args:
        network_type: Either ``"TransformerDecoder"`` or ``"DynamicFCDecoder"``.
        d_model: Feature width handed to the layer and to the final LayerNorm.
        nhead: Forwarded to the decoder layer constructor.
        dim_feedforward: Forwarded to the decoder layer constructor.
        dropout: Forwarded to the decoder layer constructor.
        activation: Forwarded to the decoder layer constructor.
        normalize_before: Forwarded to the decoder layer constructor.
        num_decoder_layers: Number of layers passed to the decoder stack.
        return_intermediate_dec: Forwarded to the decoder stack constructor.
        dynamic_K: Only used by ``"DynamicFCDecoder"``.
        dynamic_ratio: Only used by ``"DynamicFCDecoder"``.

    Returns:
        The constructed ``TransformerDecoder`` or ``DynamicFCDecoder``.

    Raises:
        ValueError: If ``network_type`` is not one of the two supported names.
    """
    if network_type == "TransformerDecoder":
        layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        final_norm = nn.LayerNorm(d_model)
        return TransformerDecoder(
            layer,
            num_decoder_layers,
            final_norm,
            return_intermediate_dec,
        )

    if network_type == "DynamicFCDecoder":
        # The style width is tied to the model width for this decoder.
        layer = DynamicFCDecoderLayer(
            d_model,
            nhead,
            d_model,
            dynamic_K,
            dynamic_ratio,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
        )
        final_norm = nn.LayerNorm(d_model)
        return DynamicFCDecoder(layer, num_decoder_layers, final_norm, return_intermediate_dec)

    raise ValueError(f"Invalid network_type {network_type}")


class DisentangleDecoder(nn.Module):
    """Decode content and style features into face coefficients with two branches.

    Two independently parameterised decoders (each followed by its own small
    MLP) predict the "upper" and "lower" coefficient groups; ``forward``
    scatters both predictions into one output tensor at the configured index
    positions.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_decoder_layers=3,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        pos_embed_len=80,
        upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
        lower_face3d_indices=tuple(range(19, 46)),
        network_type="None",
        dynamic_K=None,
        dynamic_ratio=None,
        **_,
    ) -> None:
        """Create both decoder branches, the positional encoding and the tail MLPs.

        Args:
            d_model: Feature width of the decoders and input width of the tails.
            nhead: Forwarded to ``get_decoder_network``.
            num_decoder_layers: Forwarded to ``get_decoder_network``.
            dim_feedforward: Forwarded to ``get_decoder_network``.
            dropout: Forwarded to ``get_decoder_network``.
            activation: Forwarded to ``get_decoder_network``.
            normalize_before: Forwarded to ``get_decoder_network``.
            return_intermediate_dec: Forwarded to ``get_decoder_network``.
            pos_embed_len: Second argument of ``PositionalEncoding``.
            upper_face3d_indices: Output positions written by the upper branch.
            lower_face3d_indices: Output positions written by the lower branch.
            network_type: Decoder kind; the default ``"None"`` is not accepted by
                ``get_decoder_network``, so a supported name must be supplied.
            dynamic_K: Forwarded to ``get_decoder_network``.
            dynamic_ratio: Forwarded to ``get_decoder_network``.
            **_: Extra keyword arguments are accepted and ignored.

        Raises:
            ValueError: Propagated from ``get_decoder_network`` when
                ``network_type`` is unsupported.
        """
        super().__init__()

        self.upper_face3d_indices = upper_face3d_indices
        self.lower_face3d_indices = lower_face3d_indices

        # Both branches share the same hyper-parameters but own their weights.
        decoder_args = (
            network_type,
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
            num_decoder_layers,
            return_intermediate_dec,
            dynamic_K,
            dynamic_ratio,
        )

        self.upper_decoder = get_decoder_network(*decoder_args)
        _reset_parameters(self.upper_decoder)

        self.lower_decoder = get_decoder_network(*decoder_args)
        _reset_parameters(self.lower_decoder)

        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        hidden_width = d_model // 2
        self.upper_tail_fc = self._make_tail_fc(d_model, hidden_width, len(upper_face3d_indices))
        self.lower_tail_fc = self._make_tail_fc(d_model, hidden_width, len(lower_face3d_indices))

    @staticmethod
    def _make_tail_fc(in_width, hidden_width, out_width):
        """Return the three-layer ReLU MLP used after each decoder branch."""
        return nn.Sequential(
            nn.Linear(in_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, out_width),
        )

    @staticmethod
    def _decode_region(decoder, tail_fc, queries, memory, pos, query_pos, dims):
        """Run one branch and return the tail MLP output for the centre window slot."""
        batch, frames, window, channels = dims
        # The first element of the decoder's return value is used.
        # NOTE(review): presumably the final-layer output of shape (W, B*N, C) - confirm.
        feat = decoder(queries, memory, pos=pos, query_pos=query_pos)[0]
        feat = feat.permute(1, 0, 2).reshape(batch, frames, window, channels)
        centre = feat[:, :, window // 2, :]  # (B, N, C)
        return tail_fc(centre)

    def forward(self, content, style_code):
        """Predict face coefficients from windowed content features and a style code.

        Args:
            content: 4-D tensor unpacked as (B, N, W, C).
            style_code: Tensor reshaped to (B, 1, 1, C) and broadcast over N and W.

        Returns:
            Tensor of shape (B, N, C_exp) where
            ``C_exp = len(upper_face3d_indices) + len(lower_face3d_indices)``.
        """
        batch, frames, window, channels = content.shape
        dims = (batch, frames, window, channels)
        seq_shape = (window, batch * frames, channels)

        # Broadcast the style code to every frame/window slot, then go sequence-first.
        style_seq = style_code.reshape(batch, 1, 1, channels).expand(*dims)
        style_seq = style_seq.permute(2, 0, 1, 3).reshape(*seq_shape)  # (W, B*N, C)
        memory = content.permute(2, 0, 1, 3).reshape(*seq_shape)  # (W, B*N, C)

        queries = torch.zeros_like(style_seq)
        pos = self.pos_embed(window).permute(1, 0, 2)

        upper_face3d = self._decode_region(
            self.upper_decoder, self.upper_tail_fc, queries, memory, pos, style_seq, dims
        )
        lower_face3d = self._decode_region(
            self.lower_decoder, self.lower_tail_fc, queries, memory, pos, style_seq, dims
        )

        # Scatter both partial predictions into one coefficient tensor.
        num_coeffs = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
        face3d = torch.zeros(batch, frames, num_coeffs).to(upper_face3d)
        face3d[:, :, self.upper_face3d_indices] = upper_face3d
        face3d[:, :, self.lower_face3d_indices] = lower_face3d
        return face3d