| |
| |
| |
| |
|
|
| import os |
| import torch |
| from tqdm import tqdm |
| from collections import OrderedDict |
|
|
| from models.tts.base.tts_inferece import TTSInference |
| from models.tts.jets.jets_dataset import JetsTestDataset, JetsTestCollator |
| from utils.util import load_config |
| from utils.io import save_audio |
| from models.tts.jets.jets import Jets |
| from models.vocoders.vocoder_inference import synthesis |
| from pathlib import Path |
| from processors.phone_extractor import phoneExtractor |
| from text.text_token_collation import phoneIDCollation |
| import numpy as np |
| import json |
| import time |
|
|
|
|
class JetsInference(TTSInference):
    """Inference driver for the JETS text-to-speech model.

    Plugs the ``Jets`` model and the JETS test dataset/collator into the
    generic ``TTSInference`` pipeline and implements batched synthesis.
    """

    def __init__(self, args, cfg):
        """Initialise the inference driver.

        Args:
            args: Parsed command-line arguments; ``args.mode`` is stored as
                ``self.infer_type``.
            cfg: Experiment configuration; ``cfg.preprocess.hop_size`` is
                read later to trim synthesized audio.
        """
        super().__init__(args, cfg)
        # Keep explicit references on the instance regardless of what the
        # base initializer stores.
        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

    def _build_model(self):
        """Create the ``Jets`` model from ``self.cfg``, store it on
        ``self.model`` and return it."""
        self.model = Jets(self.cfg)
        return self.model

    def _build_test_dataset(self):
        """Return the ``(dataset class, collator class)`` pair for testing."""
        return JetsTestDataset, JetsTestCollator

    def inference_for_batches(self):
        """Synthesize audio for every batch in ``self.test_dataloader``.

        Each generated waveform is trimmed to the length implied by the
        model's predicted durations: the (integer-truncated) total number of
        frames multiplied by ``cfg.preprocess.hop_size``.

        Returns:
            list: one float CPU waveform tensor per utterance, in dataloader
            order.
        """
        n_batch = len(self.test_dataloader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        hop_size = self.cfg.preprocess.hop_size
        # Only show a progress bar when there is more than one batch.
        batches = (
            self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
        )

        pred_res = []
        with torch.no_grad():
            for batch_data in batches:
                audios, d_predictions = self.model.inference(batch_data)
                # NOTE(review): d_predictions appears to be (batch, n_tokens) —
                # the original unsqueeze(-1)/sum([0, 1]) only produced a scalar
                # slice bound for a 2-D tensor; confirm against Jets.inference.
                # audios is indexed as (batch, channel, samples) below.
                for idx in range(audios.size(0)):
                    audio = audios[idx, 0, :].detach().cpu().float()
                    # Sum the predicted durations, truncate to whole frames
                    # (the `.long()` cast), then convert frames to samples.
                    n_frames = int(d_predictions[idx].sum().long().item())
                    pred_res.append(audio[: n_frames * hop_size])

        return pred_res
|
|