| | from tqdm import tqdm |
| | import torch |
| | from torch import nn |
| |
|
| |
|
class Audio2Exp(nn.Module):
    """Wrap a generator ``netG`` and run it over a batch in time windows.

    ``test`` slides over the time axis (dim 1) of the batch in fixed-size
    windows, calls ``netG`` on each window, and concatenates the per-window
    predictions along dim 1.
    """

    # Frames per ``netG`` call; the final window may be shorter.
    _CHUNK_SIZE = 10
    # Number of leading coefficients of ``batch['ref']`` (last dim) fed to netG.
    _NUM_REF_COEFFS = 64
    # Each mel window is reshaped to (-1, 1, 80, 16) for ``netG``; presumably
    # (mel bins, spectrogram frames) per video frame -- TODO confirm with loader.
    _MEL_WINDOW = (80, 16)

    def __init__(self, netG, cfg, device, prepare_training_loss=False):
        """Store the config and move ``netG`` onto ``device``.

        Args:
            netG: generator module called as ``netG(audiox, ref, ratio)``.
            cfg: configuration object; stored as-is, not read here.
            device: target device passed to ``netG.to``.
            prepare_training_loss: accepted for interface compatibility;
                currently unused.
        """
        super().__init__()
        self.cfg = cfg
        self.device = device
        self.netG = netG.to(device)

    def test(self, batch):
        """Predict coefficients for a whole batch, window by window.

        Args:
            batch: mapping with
                ``'indiv_mels'`` -- tensor indexed as (batch, time, ...), with
                    80 * 16 elements per time step (assumed; see _MEL_WINDOW);
                ``'ref'`` -- tensor indexed as (batch, time, coeffs); only the
                    first 64 coefficients are used;
                ``'ratio_gt'`` -- tensor indexed as (batch, time, ...).

        Returns:
            dict with ``'exp_coeff_pred'``: the per-window ``netG`` outputs
            concatenated along dim 1.

        Raises:
            RuntimeError: from ``torch.cat`` if the time axis is empty
                (no windows are produced).
        """
        mel_input = batch['indiv_mels']
        num_frames = mel_input.shape[1]
        # Slice once; each window below only takes a view along dim 1.
        ref_all = batch['ref'][:, :, :self._NUM_REF_COEFFS]
        ratio_all = batch['ratio_gt']

        exp_coeff_pred = []
        for start in tqdm(range(0, num_frames, self._CHUNK_SIZE), 'audio2exp:'):
            end = start + self._CHUNK_SIZE
            mel_window = mel_input[:, start:end]
            ref = ref_all[:, start:end]
            ratio = ratio_all[:, start:end]

            # reshape, not view: slicing dim 1 of a batch with more than one
            # sample yields a non-contiguous tensor whose (batch, time) dims
            # cannot be merged by view().
            audiox = mel_window.reshape(-1, 1, *self._MEL_WINDOW)

            exp_coeff_pred.append(self.netG(audiox, ref, ratio))

        return {'exp_coeff_pred': torch.cat(exp_coeff_pred, dim=1)}
| |
|
| |
|
| |
|