| import torch |
| from torch import nn |
| from src.audio2pose_models.cvae import CVAE |
| from src.audio2pose_models.discriminator import PoseSequenceDiscriminator |
| from src.audio2pose_models.audio_encoder import AudioEncoder |
|
|
class Audio2Pose(nn.Module):
    """Audio-driven head-pose generator.

    Wraps a CVAE generator (``netG``) that predicts a sequence of pose-motion
    offsets from audio embeddings, plus a pose-sequence discriminator
    (``netD_motion``) for adversarial training.  Audio features come from a
    frozen wav2lip-style ``AudioEncoder``.
    """

    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        """
        Args:
            cfg: config node providing ``MODEL.CVAE.SEQ_LEN`` and
                ``MODEL.CVAE.LATENT_SIZE``.
            wav2lip_checkpoint: pretrained weights for the audio encoder.
            device: device string used for the audio encoder and, in
                ``forward``, for moving the input tensors.
        """
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        # Frozen feature extractor: eval mode + requires_grad=False so the
        # encoder is never updated during training.
        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)

    def forward(self, x):
        """Training pass: build ground-truth motion targets and run the CVAE.

        Args:
            x: dict with keys 'gt' (coefficient tensor; channels 64:70 are
               sliced as the 6 pose coefficients -- assumed from the indexing,
               TODO confirm), 'class' (style/speaker id) and 'indiv_mels'
               (per-frame mel chunks, one per output frame plus a reference).

        Returns:
            The CVAE output dict, extended with 'pose_pred' and 'pose_gt'.
        """
        batch = {}
        # Fix: use the configured device instead of hard-coded .cuda() so the
        # module also works on CPU (identical behavior for device='cuda').
        coeff_gt = x['gt'].to(self.device).squeeze(0)
        # Motion targets are expressed relative to the first (reference) frame.
        batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70]
        batch['ref'] = coeff_gt[:, 0, 64:70]
        batch['class'] = x['class'].squeeze(0).to(self.device)
        indiv_mels = x['indiv_mels'].to(self.device).squeeze(0)

        # Skip the reference frame's mel chunk; add a channel dim for the
        # encoder's conv input.
        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2))
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred']
        pose_gt = coeff_gt[:, 1:, 64:70].clone()
        # Absolute pose = reference-frame pose + predicted motion.
        pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt

        return batch

    def test(self, x):
        """Inference pass: generate poses chunk-by-chunk over ``num_frames``.

        The sequence is produced in windows of ``self.seq_len`` frames, each
        with an independent latent sample, then concatenated.  A zero motion
        is prepended for the reference frame.

        Args:
            x: dict with 'ref' (coefficient tensor whose last 6 channels are
               taken as the pose -- assumed from the indexing, TODO confirm),
               'class', 'indiv_mels' and 'num_frames'.

        Returns:
            The last CVAE output dict, extended with the concatenated
            'pose_motion_pred' and absolute 'pose_pred'.
        """
        batch = {}
        ref = x['ref']
        batch['ref'] = x['ref'][:, 0, -6:]
        batch['class'] = x['class']
        bs = ref.shape[0]

        indiv_mels = x['indiv_mels']
        indiv_mels_use = indiv_mels[:, 1:]  # skip the reference frame's chunk
        num_frames = x['num_frames']
        num_frames = int(num_frames) - 1

        # `remainder` (not `re`) avoids shadowing the stdlib module name.
        chunks = num_frames // self.seq_len
        remainder = num_frames % self.seq_len
        # Zero motion for the reference frame seeds the output sequence.
        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape,
                                             dtype=batch['ref'].dtype,
                                             device=batch['ref'].device)]

        for i in range(chunks):
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(
                indiv_mels_use[:, i * self.seq_len:(i + 1) * self.seq_len, :, :, :])
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'])

        if remainder != 0:
            # Tail window: reuse the last seq_len mel chunks; if fewer are
            # available, left-pad the embedding by repeating its first step,
            # then keep only the last `remainder` predictions.
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, -1 * self.seq_len:, :, :, :])
            if audio_emb.shape[1] != self.seq_len:
                pad_dim = self.seq_len - audio_emb.shape[1]
                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'][:, -1 * remainder:, :])

        pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
        batch['pose_motion_pred'] = pose_motion_pred

        # Absolute pose = reference pose + predicted motion.
        pose_pred = ref[:, :1, -6:] + pose_motion_pred

        batch['pose_pred'] = pose_pred
        return batch
|
|