|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
from .transformer import ( |
|
|
PositionalEncoding, |
|
|
TransformerDecoderLayer, |
|
|
TransformerDecoder, |
|
|
) |
|
|
from core.networks.dynamic_fc_decoder import DynamicFCDecoderLayer, DynamicFCDecoder |
|
|
from core.utils import _reset_parameters |
|
|
|
|
|
|
|
|
def get_decoder_network(
    network_type,
    d_model,
    nhead,
    dim_feedforward,
    dropout,
    activation,
    normalize_before,
    num_decoder_layers,
    return_intermediate_dec,
    dynamic_K,
    dynamic_ratio,
):
    """Build a decoder stack of the requested type.

    Args:
        network_type: Either ``"TransformerDecoder"`` or ``"DynamicFCDecoder"``.
        d_model: Feature dimension of the decoder.
        nhead: Number of attention heads per layer.
        dim_feedforward: Hidden size of each layer's feed-forward sub-block.
        dropout: Dropout probability.
        activation: Activation spec forwarded to the layer constructor.
        normalize_before: Forwarded to the layer constructor (pre-norm vs
            post-norm ordering, per the layer's implementation).
        num_decoder_layers: Number of stacked decoder layers.
        return_intermediate_dec: Forwarded to the decoder; when true the
            decoder is expected to return intermediate layer outputs.
        dynamic_K: Extra capacity parameter, used only by "DynamicFCDecoder".
        dynamic_ratio: Extra ratio parameter, used only by "DynamicFCDecoder".

    Returns:
        The constructed decoder module (``TransformerDecoder`` or
        ``DynamicFCDecoder``).

    Raises:
        ValueError: If ``network_type`` is not a supported name.
    """
    if network_type == "TransformerDecoder":
        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        return TransformerDecoder(
            decoder_layer,
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate_dec,
        )

    if network_type == "DynamicFCDecoder":
        # Style conditioning uses the same width as the model features.
        d_style = d_model
        decoder_layer = DynamicFCDecoderLayer(
            d_model,
            nhead,
            d_style,
            dynamic_K,
            dynamic_ratio,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
        )
        return DynamicFCDecoder(
            decoder_layer,
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate_dec,
        )

    raise ValueError(f"Invalid network_type {network_type}")
|
|
|
|
|
|
|
|
class DisentangleDecoder(nn.Module):
    """Decode disentangled upper-/lower-face 3DMM coefficients from content features.

    Two independent decoder stacks (one for the upper face, one for the lower
    face) attend over per-frame content windows, each conditioned on a style
    code supplied as the decoder's ``query_pos``. Each stack's output is
    projected by a small MLP head to its own subset of coefficients, and the
    two subsets are scattered back into a single coefficient vector using
    ``upper_face3d_indices`` / ``lower_face3d_indices``.
    """

    def __init__(
        self,
        d_model=512,
        nhead=8,
        num_decoder_layers=3,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
        return_intermediate_dec=False,
        pos_embed_len=80,
        upper_face3d_indices=tuple(list(range(19)) + list(range(46, 51))),
        lower_face3d_indices=tuple(range(19, 46)),
        network_type="None",  # NOTE(review): default is the *string* "None"; get_decoder_network would raise on it — confirm callers always pass a real type
        dynamic_K=None,
        dynamic_ratio=None,
        **_,  # absorbs unused config keys so extra kwargs don't raise
    ) -> None:
        """Build the two decoders, positional encoding, and output MLP heads.

        Args:
            d_model: Feature dimension shared by content, style, and decoders.
            nhead: Attention heads per decoder layer.
            num_decoder_layers: Layers per decoder stack.
            dim_feedforward: Feed-forward hidden size inside each layer.
            dropout: Dropout probability inside decoder layers.
            activation: Activation spec forwarded to the decoder layers.
            normalize_before: Pre-/post-norm flag forwarded to the layers.
            return_intermediate_dec: Forwarded to the decoders.
            pos_embed_len: Maximum sequence length of the positional encoding.
            upper_face3d_indices: Coefficient indices produced by the upper head.
            lower_face3d_indices: Coefficient indices produced by the lower head.
            network_type: Decoder variant name passed to get_decoder_network.
            dynamic_K: Extra parameter, used only by the "DynamicFCDecoder" type.
            dynamic_ratio: Extra parameter, used only by the "DynamicFCDecoder" type.
        """
        super().__init__()

        # Index sets that partition the 3DMM coefficient vector between the two
        # heads (defaults: 24 upper + 27 lower = 51 coefficients).
        self.upper_face3d_indices = upper_face3d_indices
        self.lower_face3d_indices = lower_face3d_indices

        # Separate decoder stacks so upper- and lower-face motion are modeled
        # independently ("disentangled").
        self.upper_decoder = get_decoder_network(
            network_type,
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
            num_decoder_layers,
            return_intermediate_dec,
            dynamic_K,
            dynamic_ratio,
        )
        _reset_parameters(self.upper_decoder)

        self.lower_decoder = get_decoder_network(
            network_type,
            d_model,
            nhead,
            dim_feedforward,
            dropout,
            activation,
            normalize_before,
            num_decoder_layers,
            return_intermediate_dec,
            dynamic_K,
            dynamic_ratio,
        )
        _reset_parameters(self.lower_decoder)

        # Positional encoding over the temporal window (length W in forward()).
        self.pos_embed = PositionalEncoding(d_model, pos_embed_len)

        # Per-region MLP heads: d_model -> d_model/2 -> d_model/2 -> n_coeffs.
        tail_hidden_dim = d_model // 2
        self.upper_tail_fc = nn.Sequential(
            nn.Linear(d_model, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, len(upper_face3d_indices)),
        )
        self.lower_tail_fc = nn.Sequential(
            nn.Linear(d_model, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, tail_hidden_dim),
            nn.ReLU(),
            nn.Linear(tail_hidden_dim, len(lower_face3d_indices)),
        )

    def forward(self, content: torch.Tensor, style_code: torch.Tensor) -> torch.Tensor:
        """Predict per-frame 3DMM coefficients.

        Args:
            content (torch.Tensor): (B, num_frames, window, C_dmodel) content
                features — a temporal window of width W per output frame.
            style_code (torch.Tensor): (B, C_dmodel) style embedding, shared
                across all frames/windows of a sample.

        Returns:
            face3d: (B, L_clip, C_3dmm) coefficients, where L_clip ==
                num_frames and C_3dmm == len(upper) + len(lower) indices.
                Assumes the two index sets exactly partition [0, C_3dmm);
                any index outside that union would stay zero.
        """
        B, N, W, C = content.shape
        # Broadcast the per-sample style code to every (frame, window) slot,
        # then fold to the (seq_len=W, batch=B*N, C) layout the decoders expect.
        style = style_code.reshape(B, 1, 1, C).expand(B, N, W, C)
        style = style.permute(2, 0, 1, 3).reshape(W, B * N, C)

        # Same (W, B*N, C) layout for the memory the decoder attends over.
        content = content.permute(2, 0, 1, 3).reshape(W, B * N, C)

        # Zero target: the decoder output is driven entirely by memory
        # (content) and query_pos (style).
        tgt = torch.zeros_like(style)
        pos_embed = self.pos_embed(W)
        pos_embed = pos_embed.permute(1, 0, 2)  # to (W, 1, C) seq-first layout — TODO confirm PositionalEncoding's output shape

        # [0] selects the first element of the decoder's returned tuple/stack
        # (see return_intermediate_dec handling in the decoder implementation).
        upper_face3d_feat = self.upper_decoder(tgt, content, pos=pos_embed, query_pos=style)[0]

        # Unfold back to (B, N, W, C) and keep only the window's center frame.
        upper_face3d_feat = upper_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]

        upper_face3d = self.upper_tail_fc(upper_face3d_feat)

        # Same pipeline for the lower-face decoder.
        lower_face3d_feat = self.lower_decoder(tgt, content, pos=pos_embed, query_pos=style)[0]
        lower_face3d_feat = lower_face3d_feat.permute(1, 0, 2).reshape(B, N, W, C)[:, :, W // 2, :]
        lower_face3d = self.lower_tail_fc(lower_face3d_feat)

        # Scatter the two coefficient subsets into one vector; .to(upper_face3d)
        # matches dtype and device of the head output.
        C_exp = len(self.upper_face3d_indices) + len(self.lower_face3d_indices)
        face3d = torch.zeros(B, N, C_exp).to(upper_face3d)
        face3d[:, :, self.upper_face3d_indices] = upper_face3d
        face3d[:, :, self.lower_face3d_indices] = lower_face3d
        return face3d
|
|
|