import torch
from torch import nn
from torch.nn import functional as F


class Conv2d(nn.Module):
    """Conv -> BatchNorm -> ReLU block."""

    def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)


class Conv2d_res(nn.Module):
    """Residual block: Conv -> BatchNorm, with the input added back before the ReLU."""

    def __init__(self, cin, cout, kernel_size, stride, padding, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2d(cin, cout, kernel_size, stride, padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        out += x  # residual connection
        return self.act(out)
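
# Example with hypothetical sizes: the residual add requires cin == cout and
# stride 1, so the block preserves its input shape:
#   block = Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1)
#   y = block(torch.randn(2, 64, 32, 32))  # -> torch.Size([2, 64, 32, 32])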


class Conv2dTranspose(nn.Module):
    """Transposed conv -> BatchNorm -> ReLU block; upsamples in the decoder."""

    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
            nn.BatchNorm2d(cout),
        )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)
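
# For reference, nn.ConvTranspose2d produces
#     out = (in - 1) * stride - 2 * padding + kernel_size + output_padding
# per spatial dimension, so kernel_size=3, stride=2, padding=1, output_padding=1
# exactly doubles the resolution (e.g. 16x16 -> 32x32). The decoder below relies
# on this to mirror the encoder's stride-2 downsampling.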


class FETE_model(nn.Module):
    def __init__(self):
        super().__init__()

        # face encoder: each stride-2 block halves the resolution while the
        # channel count grows 16 -> 512; every intermediate feature map is
        # kept in forward() for the decoder's skip connections
        self.face_encoder_blocks = nn.ModuleList(
            [
                nn.Sequential(Conv2d(6, 16, kernel_size=7, stride=2, padding=3)),
                nn.Sequential(
                    Conv2d(16, 32, kernel_size=3, stride=2, padding=1),
                    Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
                    Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
                ),
            ]
        )
        # audio encoder: strided convolutions collapse the audio input to a
        # 512-channel embedding
        self.audio_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
        # pose encoder: maps the pose input to a 512-channel embedding
        self.pose_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
        # emotion encoder: maps the emotion input to a 512-channel embedding;
        # the opening 7x7 convolution with padding 1 shrinks both spatial
        # dimensions by 4
        self.emotion_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=7, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=3, stride=2, padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
        # blink encoder: maps the blink input to a 512-channel embedding
        self.blink_encoder = nn.Sequential(
            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
            Conv2d_res(32, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 64, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
            Conv2d(64, 128, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
            Conv2d(128, 256, kernel_size=3, stride=(1, 2), padding=1),
            Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
            Conv2d(256, 512, kernel_size=1, stride=(1, 2), padding=0),
            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
        )
        # face decoder: each block's input channel count includes the encoder
        # feature map concatenated in forward() (e.g. 512 + 512 = 1024)
        self.face_decoder_blocks = nn.ModuleList(
            [
                nn.Sequential(
                    Conv2d(2048, 512, kernel_size=1, stride=1, padding=0),
                ),
                nn.Sequential(
                    Conv2dTranspose(1024, 512, kernel_size=4, stride=1, padding=0),
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2dTranspose(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(512, 512, kernel_size=3, stride=1, padding=1),
                    Self_Attention(512, 512),
                ),
                nn.Sequential(
                    Conv2dTranspose(768, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(384, 384, kernel_size=3, stride=1, padding=1),
                    Self_Attention(384, 384),
                ),
                nn.Sequential(
                    Conv2dTranspose(512, 256, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(256, 256, kernel_size=3, stride=1, padding=1),
                    Self_Attention(256, 256),
                ),
                nn.Sequential(
                    Conv2dTranspose(320, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(128, 128, kernel_size=3, stride=1, padding=1),
                ),
                nn.Sequential(
                    Conv2dTranspose(160, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                    Conv2d_res(64, 64, kernel_size=3, stride=1, padding=1),
                ),
            ]
        )
        self.output_block = nn.Sequential(
            Conv2dTranspose(80, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Conv2d(32, 3, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(
        self,
        face_sequences,
        audio_sequences,
        pose_sequences,
        emotion_sequences,
        blink_sequences,
    ):
        B = audio_sequences.size(0)  # batch size (currently unused)
        # encode each driving signal and fuse the embeddings along channels
        audio_embedding = self.audio_encoder(audio_sequences)
        pose_embedding = self.pose_encoder(pose_sequences)
        emotion_embedding = self.emotion_encoder(emotion_sequences)
        blink_embedding = self.blink_encoder(blink_sequences)
        inputs_embedding = torch.cat((audio_embedding, pose_embedding, emotion_embedding, blink_embedding), dim=1)

        # encode the face, keeping every intermediate feature map for the
        # decoder's skip connections
        feats = []
        x = face_sequences
        for f in self.face_encoder_blocks:
            x = f(x)
            feats.append(x)
        # decode from the fused embedding, concatenating the matching encoder
        # feature map (deepest first) after every block; the channel counts of
        # the decoder blocks above assume this concatenation happens inside
        # the loop
        x = inputs_embedding
        for f in self.face_decoder_blocks:
            x = f(x)
            x = torch.cat((x, feats[-1]), dim=1)
            feats.pop()
        x = self.output_block(x)

        outputs = x

        return outputs


class Self_Attention(nn.Module):
    """
    Source-Reference Attention Layer
    """

    def __init__(self, in_planes_s, in_planes_r):
        """
        Parameters
        ----------
        in_planes_s: int
            Number of input source feature vector channels.
        in_planes_r: int
            Number of input reference feature vector channels.
        """
        super().__init__()
        self.query_conv = nn.Conv2d(in_channels=in_planes_s, out_channels=in_planes_s // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_planes_r, out_channels=in_planes_r, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, source):
        """
        Parameters
        ----------
        source : torch.Tensor
            Source feature maps (B x Cs x Hs x Ws). The same tensor is reused
            as the reference, so the layer acts as self-attention here.

        Returns
        -------
        torch.Tensor
            Input source features plus the gated attention output
            (B x Cs x Hs x Ws).
        """
        # attention is computed in float32; remember whether the input was
        # half precision so the output can be cast back
        is_half = source.dtype == torch.float16
        source = source.float() if is_half else source
        reference = source

        s_batchsize, sC, sH, sW = source.size()
        r_batchsize, rC, rH, rW = reference.size()

        proj_query = self.query_conv(source).view(s_batchsize, -1, sH * sW).permute(0, 2, 1)
        proj_key = self.key_conv(reference).view(r_batchsize, -1, rW * rH)
        energy = torch.bmm(proj_query, proj_key)  # (B, Ns, Nr) with Ns = Hs * Ws
        attention = self.softmax(energy)
        proj_value = self.value_conv(reference).view(r_batchsize, -1, rH * rW)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(s_batchsize, sC, sH, sW)
        out = self.gamma * out + source  # learnable gate, initialised to zero
        return out.half() if is_half else out
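

# Minimal smoke test, a sketch only: the input sizes below are NOT taken from
# the original repository. They are hypothetical shapes chosen so that every
# encoder reduces its input to a 512 x 1 x 1 embedding and the decoder mirrors
# the face encoder back to 256 x 256; the real data layout may differ.
if __name__ == "__main__":
    model = FETE_model().eval()  # eval mode: BatchNorm running stats work on 1x1 maps
    with torch.no_grad():
        out = model(
            face_sequences=torch.randn(2, 6, 256, 256),  # two stacked RGB frames (assumed)
            audio_sequences=torch.randn(2, 1, 80, 16),   # mel-spectrogram window (assumed)
            pose_sequences=torch.randn(2, 1, 3, 12),     # pose signal (assumed)
            emotion_sequences=torch.randn(2, 1, 8, 16),  # emotion signal (assumed)
            blink_sequences=torch.randn(2, 1, 1, 12),    # blink signal (assumed)
        )
    print(out.shape)  # expected: torch.Size([2, 3, 256, 256])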