import torch
from torch import nn

from src.audio2pose_models.cvae import CVAE
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
from src.audio2pose_models.audio_encoder import AudioEncoder

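# Audio2Pose predicts head-pose motion from audio: a frozen Wav2Lip audio
# encoder embeds per-frame mel windows, a CVAE (netG) decodes the embeddings
# into pose offsets relative to a reference frame, and a
# PoseSequenceDiscriminator (netD_motion) is available as the adversarial
# critic during training.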
class Audio2Pose(nn.Module):
    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
        super().__init__()
        self.cfg = cfg
        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
        self.device = device

        # Pretrained Wav2Lip audio encoder, frozen for both training and inference.
        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
        self.audio_encoder.eval()
        for param in self.audio_encoder.parameters():
            param.requires_grad = False

        self.netG = CVAE(cfg)
        self.netD_motion = PoseSequenceDiscriminator(cfg)
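
    # NOTE: the slice 64:70 used below is assumed to be the 6-DoF pose part
    # (3 rotation + 3 translation) of the 3DMM coefficient vector; entries
    # before index 64 are expression coefficients.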
    def forward(self, x):
        # Training pass: encode audio, predict pose motion with the CVAE,
        # and reconstruct absolute poses for the losses.
        batch = {}
        coeff_gt = x['gt'].to(self.device).squeeze(0)                              # bs, frame_len+1, num_coeffs
        batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70]  # bs, frame_len, 6
        batch['ref'] = coeff_gt[:, 0, 64:70]                                       # bs, 6
        batch['class'] = x['class'].squeeze(0).to(self.device)                     # bs
        indiv_mels = x['indiv_mels'].to(self.device).squeeze(0)                    # bs, T, 1, 80, 16

        # Embed every frame after the reference; the slice already keeps all
        # five dimensions, so no extra unsqueeze is needed.
        audio_emb = self.audio_encoder(indiv_mels[:, 1:])                          # bs, frame_len, emb_dim
        batch['audio_emb'] = audio_emb
        batch = self.netG(batch)

        pose_motion_pred = batch['pose_motion_pred']                               # bs, frame_len, 6
        pose_gt = coeff_gt[:, 1:, 64:70].clone()                                   # bs, frame_len, 6
        pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred                      # bs, frame_len, 6

        batch['pose_pred'] = pose_pred
        batch['pose_gt'] = pose_gt
        return batch
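
    # Pose is parameterised residually: pose_pred[t] = pose_ref +
    # pose_motion_pred[t], mirroring how pose_motion_gt is built in forward().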
    def test(self, x):
        # Inference pass: sample pose motion window by window.
        batch = {}
        ref = x['ref']                            # bs, 1, num_coeffs
        batch['ref'] = x['ref'][:, 0, -6:]        # bs, 6
        batch['class'] = x['class']
        bs = ref.shape[0]

        indiv_mels = x['indiv_mels']              # bs, T, 1, 80, 16
        indiv_mels_use = indiv_mels[:, 1:]        # the reference counts as frame 0
        num_frames = int(x['num_frames']) - 1

        # Split the remaining frames into full seq_len windows plus a remainder.
        div = num_frames // self.seq_len
        re = num_frames % self.seq_len
        # Seed the sequence with zero motion for the reference frame.
        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape,
                                             dtype=batch['ref'].dtype,
                                             device=batch['ref'].device)]
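
        # A fresh latent z is drawn for every window, so repeated runs can
        # produce different head motions for the same audio.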
        # Full windows: encode seq_len frames of audio and decode their motion.
        for i in range(div):
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, i * self.seq_len:(i + 1) * self.seq_len])
            batch['audio_emb'] = audio_emb                            # bs, seq_len, emb_dim
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'])   # bs, seq_len, 6

        # Remainder: feed the last seq_len frames of audio, left-padding the
        # embeddings if fewer are available, and keep only the final re frames.
        if re != 0:
            z = torch.randn(bs, self.latent_dim).to(ref.device)
            batch['z'] = z
            audio_emb = self.audio_encoder(indiv_mels_use[:, -self.seq_len:])
            if audio_emb.shape[1] != self.seq_len:
                pad_dim = self.seq_len - audio_emb.shape[1]
                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
            batch['audio_emb'] = audio_emb
            batch = self.netG.test(batch)
            pose_motion_pred_list.append(batch['pose_motion_pred'][:, -re:, :])

        # Stitch the windows together and convert motion to absolute pose.
        pose_motion_pred = torch.cat(pose_motion_pred_list, dim=1)
        batch['pose_motion_pred'] = pose_motion_pred

        pose_pred = ref[:, :1, -6:] + pose_motion_pred        # bs, num_frames+1, 6

        batch['pose_pred'] = pose_pred
        return batch
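
# Minimal usage sketch (illustrative only; `cfg` is assumed to expose
# MODEL.CVAE.SEQ_LEN and MODEL.CVAE.LATENT_SIZE, and the checkpoint path
# below is hypothetical):
#
#   model = Audio2Pose(cfg, './checkpoints/wav2lip.pth', device='cuda').eval()
#   with torch.no_grad():
#       out = model.test({'ref': ref_coeffs,        # bs, 1, num_coeffs
#                         'class': style_id,        # bs
#                         'indiv_mels': mels,       # bs, T, 1, 80, 16
#                         'num_frames': T})
#   pose_pred = out['pose_pred']                    # bs, T, 6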