import torch
import torch.nn as nn

from models.util import MyResNet34


class audio2poseLSTM(nn.Module):
    def __init__(self):
        super().__init__()

        # Pose and audio encoders: ResNet-34 backbones mapping a
        # single-channel input to a 256-dim embedding.
        self.em_pose = MyResNet34(256, 1)
        self.em_audio = MyResNet34(256, 1)
        # Two-layer LSTM over the concatenated audio + pose embeddings
        # (256 + 256 = 512), with a 256-dim hidden state.
        self.lstm = nn.LSTM(512, 256, num_layers=2, bias=True, batch_first=True)
        # Project each hidden state to a 6-dim pose vector.
        self.output = nn.Linear(256, 6)
    def forward(self, x):
        # Encode the reference image into a 256-dim pose embedding.
        pose_em = self.em_pose(x["img"])
        bs = pose_em.shape[0]

        # Zero-initialized (h, c) states for the two-layer LSTM.
        zero_state = torch.zeros((2, bs, 256), requires_grad=True).to(pose_em.device)
        cur_state = (zero_state, zero_state)
        img_em = pose_em

        # Encode every audio window in one pass by folding the sequence
        # dimension into the batch, then restore (bs, seqlen, 256).
        bs, seqlen, num, dims = x["audio"].shape
        audio = x["audio"].reshape(-1, 1, num, dims)
        audio_em = self.em_audio(audio).reshape(bs, seqlen, 256)

        # The pose decoded from the reference image is the first output.
        result = [self.output(img_em).unsqueeze(1)]
        # Autoregressive rollout: each step feeds the current audio embedding
        # concatenated with the previous step's pose embedding.
        for i in range(seqlen):
            img_em, cur_state = self.lstm(
                torch.cat((audio_em[:, i:i + 1], img_em.unsqueeze(1)), dim=2),
                cur_state,
            )
            img_em = img_em.reshape(-1, 256)
            result.append(self.output(img_em).unsqueeze(1))

        # (bs, seqlen + 1, 6): one 6-dim pose per step, plus the initial pose.
        return torch.cat(result, dim=1)
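

# A minimal smoke test, sketched under assumptions: the shapes below (image
# size, audio window size, sequence length) are hypothetical placeholders,
# and it presumes the repo's MyResNet34(256, 1) accepts 1-channel inputs of
# these sizes. Only the dict keys "img" and "audio" and the output layout
# (bs, seqlen + 1, 6) come from the model above.
if __name__ == "__main__":
    model = audio2poseLSTM().eval()
    batch = {
        "img": torch.randn(2, 1, 64, 64),     # (bs, C, H, W) reference frames
        "audio": torch.randn(2, 10, 32, 32),  # (bs, seqlen, num, dims) audio windows
    }
    with torch.no_grad():
        poses = model(batch)
    print(poses.shape)  # expected: torch.Size([2, 11, 6])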