| | import torch.nn as nn |
| | import torch |
| | from models.util import mydownres2Dblock |
| | import numpy as np |
| | from models.util import AntiAliasInterpolation2d,make_coordinate_grid |
| | from sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d |
| | import torch.nn.functional as F |
| | import copy |
| |
|
| |
|
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional-encoding lookup.

    The full table is computed once at construction time and kept as the
    buffer ``pos_table`` with shape ``(1, n_position, d_hid)``.
    """

    def __init__(self, d_hid, n_position=200):
        """Build the table for ``n_position`` positions of width ``d_hid``."""
        super(PositionalEncoding, self).__init__()
        table = self._get_sinusoid_encoding_table(n_position, d_hid)
        # Registered as a buffer rather than a parameter: it moves with the
        # module across devices but is never touched by an optimiser.
        self.register_buffer('pos_table', table)

    def _get_sinusoid_encoding_table(self, n_position, d_hid):
        """Return the sinusoid table as a float tensor of shape (1, n_position, d_hid).

        Even channels hold ``sin(pos / 10000^(2k/d_hid))`` and odd channels the
        matching ``cos``, where ``k = channel // 2``.
        """
        # The divisor depends only on the channel, so compute it once per
        # channel instead of once per (position, channel) pair.
        divisors = [np.power(10000, 2 * (channel // 2) / d_hid)
                    for channel in range(d_hid)]
        angles = np.array([[position / divisor for divisor in divisors]
                           for position in range(n_position)])

        angles[:, 0::2] = np.sin(angles[:, 0::2])  # even channels
        angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd channels

        # Leading singleton axis so the table broadcasts over a batch.
        return torch.FloatTensor(angles).unsqueeze(0)

    def forward(self, winsize):
        """Return a detached copy of the first ``winsize`` table rows.

        Plain slicing is used, so asking for more rows than the table holds
        yields only the rows that exist.
        """
        window = self.pos_table[:, :winsize]
        return window.clone().detach()
| |
|
def _get_activation_fn(activation):
    """Return the functional activation selected by name.

    Args:
        activation: one of ``"relu"``, ``"gelu"`` or ``"glu"``.

    Returns:
        The matching callable from ``torch.nn.functional``.

    Raises:
        RuntimeError: if ``activation`` is not one of the supported names.
    """
    supported = (("relu", F.relu), ("gelu", F.gelu), ("glu", F.glu))
    # Compare with ``==`` (not a dict lookup) so that unhashable or otherwise
    # unexpected values still end in the RuntimeError below.
    for name, fn in supported:
        if activation == name:
            return fn
    # The message now lists every accepted name, including "glu".
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
| |
|
def _get_clones(module, N):
    """Return an ``nn.ModuleList`` of ``N`` independent deep copies of ``module``."""
    clones = nn.ModuleList()
    for _ in range(N):
        # deepcopy so that the copies do not share parameters with each other
        # or with the prototype.
        clones.append(copy.deepcopy(module))
    return clones
| |
|
class Transformer(nn.Module):
    """Encoder-decoder transformer built from this module's layer classes.

    The encoder consumes ``src`` plus a positional embedding; the decoder
    starts from an all-zero target and attends to the encoder output.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=True):
        """Create the encoder and decoder stacks and initialise their weights.

        ``normalize_before`` selects the pre-norm layer variant and, when set,
        also appends a final ``LayerNorm`` to the encoder. The decoder always
        gets a final ``LayerNorm``.
        """
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead

        # Both layer prototypes take exactly the same constructor arguments.
        layer_args = (d_model, nhead, dim_feedforward,
                      dropout, activation, normalize_before)

        enc_proto = TransformerEncoderLayer(*layer_args)
        enc_norm = None
        if normalize_before:
            enc_norm = nn.LayerNorm(d_model)
        self.encoder = TransformerEncoder(enc_proto, num_encoder_layers, enc_norm)

        dec_proto = TransformerDecoderLayer(*layer_args)
        dec_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(dec_proto, num_decoder_layers, dec_norm,
                                          return_intermediate=return_intermediate_dec)

        # Must run after both stacks exist so that every weight is covered.
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        matrices = (param for param in self.parameters() if param.dim() > 1)
        for weight in matrices:
            nn.init.xavier_uniform_(weight)

    def forward(self, opt, src, query_embed, pos_embed):
        """Run encoder then decoder and return the decoder output.

        ``src``, ``query_embed`` and ``pos_embed`` are 3-D tensors whose first
        two axes are swapped before use (presumably batch-first input turned
        into the sequence-first layout of ``nn.MultiheadAttention`` -- confirm
        against the caller). ``opt`` is accepted but not used here.
        """
        src, pos_embed, query_embed = (
            tensor.permute(1, 0, 2) for tensor in (src, pos_embed, query_embed)
        )

        memory = self.encoder(src, pos=pos_embed)

        # The decoder input carries no content of its own; all information
        # enters through the positional and query embeddings.
        tgt = torch.zeros_like(query_embed)
        return self.decoder(tgt, memory, pos=pos_embed, query_pos=query_embed)
| |
|
| |
|
class TransformerEncoder(nn.Module):
    """Stack of identical encoder layers with an optional final normalisation."""

    def __init__(self, encoder_layer, num_layers, norm=None):
        """Clone ``encoder_layer`` ``num_layers`` times.

        Args:
            encoder_layer: prototype layer; it is deep-copied, not reused.
            num_layers: number of copies in the stack.
            norm: optional module applied to the output of the last layer.
        """
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src, mask=None, src_key_padding_mask=None, pos=None):
        """Encode ``src``.

        The positional embedding ``pos`` is added to ``src`` once, before the
        first layer; it is also forwarded to every layer as ``pos``.

        Args:
            src: input sequence tensor.
            mask: attention mask forwarded to every layer as ``src_mask``.
            src_key_padding_mask: key padding mask forwarded to every layer.
            pos: optional positional embedding broadcastable to ``src``.

        Returns:
            The output of the last layer, passed through ``norm`` if set.
        """
        # ``pos`` defaults to None, so the addition has to be conditional;
        # adding None to a tensor raises a TypeError.
        output = src if pos is None else src + pos

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output
| |
|
| |
|
class TransformerDecoder(nn.Module):
    """Stack of identical decoder layers with an optional final normalisation."""

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        """Clone ``decoder_layer`` ``num_layers`` times.

        Args:
            decoder_layer: prototype layer; it is deep-copied, not reused.
            num_layers: number of copies in the stack.
            norm: optional module applied to layer outputs (see ``forward``).
            return_intermediate: if true, ``forward`` returns the output of
                every layer instead of only the last one.
        """
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None,
                memory_key_padding_mask=None,
                pos=None,
                query_pos=None):
        """Decode ``tgt`` against ``memory``.

        ``pos`` and ``query_pos`` are added to ``tgt`` once, before the first
        layer, and are also forwarded to every layer.

        Returns:
            With ``return_intermediate``: the per-layer outputs stacked on a
            new leading axis (each normalised when ``norm`` is set).
            Otherwise: the final output with a leading axis of size 1.
        """
        # Both embeddings default to None; adding None to a tensor raises, so
        # only the ones actually supplied are added.
        output = tgt
        if pos is not None:
            output = output + pos
        if query_pos is not None:
            output = output + query_pos

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                # ``norm`` is optional: collect the raw layer output when no
                # normalisation module was given instead of calling None.
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Swap the last collected entry for the final normalised output.
                if intermediate:
                    intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
| |
|
| |
|
class TransformerEncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by a feed-forward network.

    Each sub-layer sits inside a residual connection. ``normalize_before``
    selects whether ``LayerNorm`` is applied before each sub-layer
    (``forward_pre``) or after the residual addition (``forward_post``).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _feed_forward(self, features):
        """Apply linear1 -> activation -> dropout -> linear2."""
        hidden = self.activation(self.linear1(features))
        return self.linear2(self.dropout(hidden))

    def forward_post(self,
                     src,
                     src_mask=None,
                     src_key_padding_mask=None,
                     pos=None):
        """Post-norm variant: normalise after each residual addition.

        ``pos`` is accepted for interface symmetry but is not used here.
        """
        attended, _ = self.self_attn(src, src, value=src, attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = self.norm1(src + self.dropout1(attended))
        src = self.norm2(src + self.dropout2(self._feed_forward(src)))
        return src

    def forward_pre(self, src,
                    src_mask=None,
                    src_key_padding_mask=None,
                    pos=None):
        """Pre-norm variant: normalise the input of each sub-layer.

        ``pos`` is accepted for interface symmetry but is not used here.
        """
        normed = self.norm1(src)
        attended, _ = self.self_attn(normed, normed, value=normed, attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        src = src + self.dropout1(attended)

        normed = self.norm2(src)
        src = src + self.dropout2(self._feed_forward(normed))
        return src

    def forward(self, src,
                src_mask=None,
                src_key_padding_mask=None,
                pos=None):
        """Dispatch to the pre-norm or post-norm implementation."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(src, src_mask, src_key_padding_mask, pos)
| |
|
| |
|
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, feed-forward.

    Each sub-layer sits inside a residual connection. ``normalize_before``
    selects whether ``LayerNorm`` is applied before each sub-layer
    (``forward_pre``) or after the residual addition (``forward_post``).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Cross-attention: queries come from the target, keys/values from memory.
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        # Feed-forward network: d_model -> dim_feedforward -> d_model.
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos):
        """Return ``tensor + pos``, or ``tensor`` unchanged when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _feed_forward(self, features):
        """Apply linear1 -> activation -> dropout -> linear2."""
        hidden = self.activation(self.linear1(features))
        return self.linear2(self.dropout(hidden))

    def forward_post(self, tgt, memory,
                     tgt_mask=None,
                     memory_mask=None,
                     tgt_key_padding_mask=None,
                     memory_key_padding_mask=None,
                     pos=None,
                     query_pos=None):
        """Post-norm variant: normalise after each residual addition.

        ``pos`` and ``query_pos`` are accepted for interface symmetry but are
        not used here.
        """
        self_attended, _ = self.self_attn(tgt, tgt, value=tgt, attn_mask=tgt_mask,
                                          key_padding_mask=tgt_key_padding_mask)
        tgt = self.norm1(tgt + self.dropout1(self_attended))

        cross_attended, _ = self.multihead_attn(query=tgt,
                                                key=memory,
                                                value=memory, attn_mask=memory_mask,
                                                key_padding_mask=memory_key_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(cross_attended))

        tgt = self.norm3(tgt + self.dropout3(self._feed_forward(tgt)))
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask=None,
                    memory_mask=None,
                    tgt_key_padding_mask=None,
                    memory_key_padding_mask=None,
                    pos=None,
                    query_pos=None):
        """Pre-norm variant: normalise the input of each sub-layer.

        ``pos`` and ``query_pos`` are accepted for interface symmetry but are
        not used here.
        """
        normed = self.norm1(tgt)
        self_attended, _ = self.self_attn(normed, normed, value=normed, attn_mask=tgt_mask,
                                          key_padding_mask=tgt_key_padding_mask)
        tgt = tgt + self.dropout1(self_attended)

        normed = self.norm2(tgt)
        cross_attended, _ = self.multihead_attn(query=normed,
                                                key=memory,
                                                value=memory, attn_mask=memory_mask,
                                                key_padding_mask=memory_key_padding_mask)
        tgt = tgt + self.dropout2(cross_attended)

        normed = self.norm3(tgt)
        tgt = tgt + self.dropout3(self._feed_forward(normed))
        return tgt

    def forward(self, tgt, memory,
                tgt_mask=None,
                memory_mask=None,
                tgt_key_padding_mask=None,
                memory_key_padding_mask=None,
                pos=None,
                query_pos=None):
        """Dispatch to the pre-norm or post-norm implementation."""
        branch = self.forward_pre if self.normalize_before else self.forward_post
        return branch(tgt, memory, tgt_mask, memory_mask,
                      tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
| |
|
| |
|
| |
|
class Audio2kpTransformer(nn.Module):
    """Predict keypoint values and jacobians from phoneme, pose and audio windows.

    Two convolutional feature extractors turn each frame of the window into a
    512-d vector (one stream for pose + phoneme, one for audio + the initial
    keypoint feature map). A transformer combines the two streams, and two
    linear heads read the keypoints from the decoder output at window index
    ``opt.num_w``.

    Reads ``opt.embedding_dim``, ``opt.num_kp`` and ``opt.num_w``.
    """

    def __init__(self, opt):
        super(Audio2kpTransformer, self).__init__()
        self.opt = opt

        # Phoneme ids are looked up in a 41-entry table. ``forward`` reshapes
        # each embedding to 16x16, so opt.embedding_dim has to be 256.
        self.embedding = nn.Embedding(41, opt.embedding_dim)
        self.pos_enc = PositionalEncoding(512, 20)
        # NOTE(review): presumably downsamples the 256x256 pose map by 4x to
        # 64x64 -- confirm against AntiAliasInterpolation2d.
        self.down_pose = AntiAliasInterpolation2d(1, 0.25)
        input_dim = 2  # one pose channel + one phoneme channel
        self.feature_extract = nn.Sequential(mydownres2Dblock(input_dim, 32),
                                             mydownres2Dblock(32, 64),
                                             mydownres2Dblock(64, 128),
                                             mydownres2Dblock(128, 256),
                                             mydownres2Dblock(256, 512),
                                             nn.AvgPool2d(2))

        # 35 audio channels + 35 channels of the initial keypoint feature map.
        self.decode_dim = 70
        # A (1, 4, 41) audio patch becomes (8, 32, 32) after the transposed
        # convolution; the 13x13 convolution keeps 32x32 and emits 35 channels.
        self.audio_embedding = nn.Sequential(nn.ConvTranspose2d(1, 8, (29, 14), stride=(1, 1), padding=(0, 11)),
                                             BatchNorm2d(8),
                                             nn.ReLU(inplace=True),
                                             nn.Conv2d(8, 35, (13, 13), stride=(1, 1), padding=(6, 6)))
        self.decodefeature_extract = nn.Sequential(mydownres2Dblock(self.decode_dim, 32),
                                                   mydownres2Dblock(32, 64),
                                                   mydownres2Dblock(64, 128),
                                                   mydownres2Dblock(128, 256),
                                                   mydownres2Dblock(256, 512),
                                                   nn.AvgPool2d(2))

        self.transformer = Transformer()
        self.kp = nn.Linear(512, opt.num_kp * 2)
        self.jacobian = nn.Linear(512, opt.num_kp * 4)
        # Zero weights plus a [1, 0, 0, 1] bias per keypoint: the jacobian head
        # initially outputs the 2x2 identity for every keypoint.
        self.jacobian.weight.data.zero_()
        self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 1] * self.opt.num_kp, dtype=torch.float))
        self.criterion = nn.L1Loss()

    def create_sparse_motions(self, source_image, kp_source):
        """
        Eq 4. in the paper T_{s<-d}(z)

        Build one coordinate grid per keypoint, plus an identity grid, at the
        spatial size of ``source_image``.

        Args:
            source_image: 4-D tensor; only its shape ``(bs, _, h, w)`` is used.
            kp_source: dict with ``'value'`` (viewed as ``(bs, num_kp, 2)``)
                and optionally ``'jacobian'`` (presumably
                ``(bs, num_kp, 2, 2)`` -- confirm against the caller).

        Returns:
            Tensor of shape ``(bs, (num_kp + 1) * 2, h, w)``.
        """
        bs, _, h, w = source_image.shape
        identity_grid = make_coordinate_grid((h, w), type=kp_source['value'].type())
        identity_grid = identity_grid.view(1, 1, h, w, 2)
        coordinate_grid = identity_grid
        if 'jacobian' in kp_source:
            # Apply each keypoint's 2x2 matrix to every grid coordinate.
            jacobian = kp_source['jacobian']
            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3)
            jacobian = jacobian.repeat(1, 1, h, w, 1, 1)
            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
            coordinate_grid = coordinate_grid.squeeze(-1)

        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.opt.num_kp, 1, 1, 2)

        # The identity grid is prepended as an extra motion.
        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1)
        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)

        # Fold the (x, y) axis into the channel axis. The spatial size comes
        # from the input instead of a hard-coded 64x64, so other resolutions
        # reshape correctly; 64x64 inputs behave exactly as before.
        return sparse_motions.permute(0, 1, 4, 2, 3).reshape(bs, (self.opt.num_kp + 1) * 2, h, w)

    def forward(self, x, initial_kp=None):
        """Predict keypoints for one window of frames.

        Args:
            x: dict with
                ``"ph"``    -- phoneme ids, shape ``(bs, seqlen)``;
                ``"pose"``  -- reshapeable to ``(bs * seqlen, 1, 256, 256)``;
                ``"audio"`` -- reshapeable to ``(bs * seqlen, 1, 4, 41)``.
            initial_kp: dict whose ``"feature_map"`` entry is tiled over the
                window and reshaped to ``(bs * seqlen, 35, 64, 64)``. Required
                despite the ``None`` default.

        Returns:
            dict with ``"value"`` of shape ``(bs, num_kp, 2)`` and
            ``"jacobian"`` of shape ``(bs, num_kp, 2, 2)``.

        Raises:
            TypeError: if ``initial_kp`` is ``None``.
        """
        if initial_kp is None:
            # Same exception type that subscripting None raised before, but
            # raised up front with an actionable message.
            raise TypeError("initial_kp must be a mapping with a 'feature_map' entry, got None")

        bs, seqlen = x["ph"].shape
        ph = x["ph"].reshape(bs * seqlen, 1)
        pose = x["pose"].reshape(bs * seqlen, 1, 256, 256)
        input_feature = self.down_pose(pose)

        # Lay each phoneme embedding out as a 16x16 map and upsample it to
        # 64x64 before stacking it with the downsampled pose map.
        phoneme_embedding = self.embedding(ph.long())
        phoneme_embedding = phoneme_embedding.reshape(bs * seqlen, 1, 16, 16)
        phoneme_embedding = F.interpolate(phoneme_embedding, scale_factor=4)
        input_feature = torch.cat((input_feature, phoneme_embedding), dim=1)

        input_feature = self.feature_extract(input_feature).unsqueeze(-1).reshape(bs, seqlen, 512)

        audio = x["audio"].reshape(bs * seqlen, 1, 4, 41)
        decoder_feature = self.audio_embedding(audio)
        decoder_feature = F.interpolate(decoder_feature, scale_factor=2)  # 32x32 -> 64x64
        # The same initial feature map is repeated for every frame of the window.
        tiled_initial = initial_kp["feature_map"].unsqueeze(1).repeat(
            1, seqlen, 1, 1, 1).reshape(bs * seqlen, 35, 64, 64)
        decoder_feature = self.decodefeature_extract(
            torch.cat((decoder_feature, tiled_initial), dim=1)
        ).unsqueeze(-1).reshape(bs, seqlen, 512)

        # NOTE(review): the positional table covers 2 * num_w + 1 positions, so
        # seqlen is presumably expected to equal that -- confirm with the caller.
        posi_em = self.pos_enc(self.opt.num_w * 2 + 1)

        # [-1, num_w]: last decoder layer, window position num_w.
        output_feature = self.transformer(
            self.opt, input_feature, decoder_feature, posi_em)[-1, self.opt.num_w]

        out = {}
        out["value"] = self.kp(output_feature).reshape(bs, self.opt.num_kp, 2)
        out["jacobian"] = self.jacobian(output_feature).reshape(bs, self.opt.num_kp, 2, 2)

        return out
| |
|
| |
|
| |
|