# File size: 2,349 Bytes (revision 5c69097)
import torch
import torch.nn as nn
from model.audioEncoder import audioEncoder
from model.visualEncoder import visualFrontend, visualTCN, visualConv1D
from model.attentionLayer import attentionLayer
class talkNetModel(nn.Module):
    """Audio-visual model combining two temporal encoders with attention.

    The model is composed of:

    * a visual branch: ``visualFrontend`` -> ``visualTCN`` -> ``visualConv1D``
    * an audio branch: ``audioEncoder``
    * two cross-attention layers (``crossA2V`` / ``crossV2A``)
    * one self-attention layer over the concatenated streams (``selfAV``)

    There is no single ``forward``; callers invoke the ``forward_*`` stages
    individually.
    """

    # Pixel normalisation: values are divided by _PIXEL_SCALE, then shifted
    # and scaled. NOTE(review): 0.4161 / 0.1688 look like a dataset mean and
    # standard deviation for grayscale frames -- confirm against training.
    _PIXEL_SCALE = 255
    _PIXEL_MEAN = 0.4161
    _PIXEL_STD = 0.1688

    # Per-frame feature size the visual frontend output is reshaped to.
    _FRONTEND_DIM = 512
    # Width used by the cross-attention layers and the single-stream backends.
    _EMBED_DIM = 128
    # Width after concatenating the two streams along the last axis.
    _FUSED_DIM = 2 * _EMBED_DIM

    def __init__(self):
        super().__init__()

        # Visual temporal encoder.
        self.visualFrontend = visualFrontend()
        # NOTE: loading pretrained frontend weights from 'visual_frontend.pt'
        # and freezing them (requires_grad = False) is currently disabled.
        self.visualTCN = visualTCN()
        self.visualConv1D = visualConv1D()

        # Audio temporal encoder.
        self.audioEncoder = audioEncoder(
            layers=[3, 4, 6, 3],
            num_filters=[16, 32, 64, 128],
        )

        # Audio-visual cross attention (one layer per direction).
        self.crossA2V = attentionLayer(d_model=self._EMBED_DIM, nhead=8)
        self.crossV2A = attentionLayer(d_model=self._EMBED_DIM, nhead=8)

        # Audio-visual self attention over the concatenated features.
        self.selfAV = attentionLayer(d_model=self._FUSED_DIM, nhead=8)

    def forward_visual_frontend(self, x):
        """Encode a batch of frame sequences with the visual branch.

        Args:
            x: 4-D tensor unpacked as ``(B, T, W, H)``; presumably raw pixel
                intensities in ``[0, 255]`` -- confirm with the data loader.

        Returns:
            The output of ``visualConv1D`` with its last two axes swapped
            (time-major again). Presumably ``(B, T, 128)`` to match the
            attention layers -- confirm against ``visualConv1D``.
        """
        B, T, W, H = x.shape
        # reshape (not view): merging the batch and time axes must also work
        # for non-contiguous input such as a transposed (T, B, ...) batch,
        # where view() raises a RuntimeError.
        x = x.reshape(B * T, 1, 1, W, H)
        x = (x / self._PIXEL_SCALE - self._PIXEL_MEAN) / self._PIXEL_STD
        x = self.visualFrontend(x)
        # Restore the (batch, time) axes, then go channel-first for the
        # temporal modules and back afterwards.
        x = x.reshape(B, T, self._FRONTEND_DIM)
        x = x.transpose(1, 2)
        x = self.visualTCN(x)
        x = self.visualConv1D(x)
        return x.transpose(1, 2)

    def forward_audio_frontend(self, x):
        """Encode audio features with ``audioEncoder``.

        Args:
            x: tensor with at least three axes. A singleton axis is inserted
                at position 1 and the resulting axes 2 and 3 are swapped
                before encoding. Presumably ``(B, frames, features)`` --
                confirm with the caller.

        Returns:
            Whatever ``audioEncoder`` returns for the rearranged input.
        """
        x = x.unsqueeze(1).transpose(2, 3)
        return self.audioEncoder(x)

    def forward_cross_attention(self, x1, x2):
        """Apply cross attention in both directions.

        Args:
            x1: first stream; ``src`` of ``crossA2V`` and ``tar`` of
                ``crossV2A``.
            x2: second stream; ``tar`` of ``crossA2V`` and ``src`` of
                ``crossV2A``.

        Returns:
            Tuple ``(x1_c, x2_c)`` holding the ``crossA2V`` and ``crossV2A``
            outputs respectively.
        """
        x1_c = self.crossA2V(src=x1, tar=x2)
        x2_c = self.crossV2A(src=x2, tar=x1)
        return x1_c, x2_c

    def forward_audio_visual_backend(self, x1, x2):
        """Fuse two streams and flatten the result to rows of width 256.

        The inputs are concatenated along axis 2, passed through ``selfAV``
        (as both ``src`` and ``tar``) and reshaped to ``(-1, 256)``.
        """
        x = torch.cat((x1, x2), 2)
        x = self.selfAV(src=x, tar=x)
        return torch.reshape(x, (-1, self._FUSED_DIM))

    def forward_audio_backend(self, x):
        """Flatten audio features to a 2-D tensor of shape ``(-1, 128)``."""
        return torch.reshape(x, (-1, self._EMBED_DIM))

    def forward_visual_backend(self, x):
        """Flatten visual features to a 2-D tensor of shape ``(-1, 128)``."""
        return torch.reshape(x, (-1, self._EMBED_DIM))
|