| from multiprocessing.pool import Pool
|
|
|
| import matplotlib
|
|
|
| from utils.pl_utils import data_loader
|
| from utils.training_utils import RSQRTSchedule
|
| from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
|
| from modules.fastspeech.pe import PitchExtractor
|
|
|
| matplotlib.use('Agg')
|
| import os
|
| import numpy as np
|
| from tqdm import tqdm
|
| import torch.distributed as dist
|
|
|
| from tasks.base_task import BaseTask
|
| from utils.hparams import hparams
|
| from utils.text_encoder import TokenTextEncoder
|
| import json
|
|
|
| import torch
|
| import torch.optim
|
| import torch.utils.data
|
| import utils
|
|
|
|
|
|
|
class TtsTask(BaseTask):
    """TTS task scaffolding built on top of ``BaseTask``.

    Bundles phone-encoder construction, optimizer/scheduler factories, a
    batch-list driven data loader, and the test-time setup/teardown of the
    vocoder, optional pitch extractor and result-saving worker pool.
    """

    def __init__(self, *args, **kwargs):
        """Initialise task state, then delegate to ``BaseTask.__init__``.

        All positional and keyword arguments are forwarded unchanged to the
        base class.
        """
        # NOTE(review): every attribute below is assigned *before*
        # super().__init__(), as in the original code; presumably the base
        # initialiser may rely on them -- confirm before reordering.
        self.vocoder = None  # instantiated in test_start()
        self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir'])
        # Special token ids as reported by the phone encoder.
        self.padding_idx = self.phone_encoder.pad()
        self.eos_idx = self.phone_encoder.eos()
        self.seg_idx = self.phone_encoder.seg()
        # Worker pool and its pending results; both are created in test_start().
        self.saving_result_pool = None
        self.saving_results_futures = None
        self.stats = {}
        super().__init__(*args, **kwargs)

    def build_scheduler(self, optimizer):
        """Return an ``RSQRTSchedule`` wrapping ``optimizer``."""
        return RSQRTSchedule(optimizer)

    def build_optimizer(self, model):
        """Create an AdamW optimizer over ``model.parameters()``.

        The learning rate is read from ``hparams['lr']``. The optimizer is
        also stored on ``self.optimizer`` before being returned.
        """
        self.optimizer = optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=hparams['lr'])
        return optimizer

    def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None,
                         required_batch_size_multiple=-1, endless=False, batch_by_size=True):
        """Build a ``DataLoader`` driven by a precomputed list of index batches.

        Args:
            dataset: Object providing ``ordered_indices()``, ``num_tokens``,
                ``collater`` and ``num_workers``.
            shuffle: Whether to shuffle the order of the batches.
            max_tokens: Token budget per batch; multiplied by the CUDA device
                count (or 1 when no device is visible). ``None`` is passed
                through unchanged.
            max_sentences: Sentence limit per batch; scaled the same way.
                Required when ``batch_by_size`` is false.
            required_batch_size_multiple: Forwarded to
                ``utils.batch_by_size``; ``-1`` means "use the device count".
            endless: If true, the batch list is repeated 1000 times (and
                reshuffled on every repetition when ``shuffle`` is set).
            batch_by_size: If true, batches come from
                ``utils.batch_by_size``; otherwise the ordered indices are cut
                into consecutive chunks of ``max_sentences``.

        Returns:
            A ``torch.utils.data.DataLoader`` over ``dataset`` using the
            computed batches as its ``batch_sampler``.

        Raises:
            ValueError: If ``batch_by_size`` is false and ``max_sentences`` is
                ``None`` (there would be no chunk size to split by).
        """
        # Fall back to a single "device" when CUDA reports none.
        devices_cnt = torch.cuda.device_count() or 1
        if required_batch_size_multiple == -1:
            required_batch_size_multiple = devices_cnt

        def shuffle_batches(batches):
            # In-place shuffle through NumPy's global RNG; returns the same list.
            np.random.shuffle(batches)
            return batches

        # Limits are given per device; scale them to the whole step.
        if max_tokens is not None:
            max_tokens *= devices_cnt
        if max_sentences is not None:
            max_sentences *= devices_cnt

        indices = dataset.ordered_indices()
        if batch_by_size:
            batch_sampler = utils.batch_by_size(
                indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
        else:
            if max_sentences is None:
                raise ValueError(
                    'max_sentences must be provided when batch_by_size is False')
            batch_sampler = [
                indices[start:start + max_sentences]
                for start in range(0, len(indices), max_sentences)
            ]

        if shuffle:
            batches = shuffle_batches(list(batch_sampler))
            if endless:
                # A fresh copy is shuffled for each of the 1000 repetitions.
                batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))]
        else:
            batches = batch_sampler
            if endless:
                batches = [b for _ in range(1000) for b in batches]

        num_workers = dataset.num_workers
        if self.trainer.use_ddp:
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
            # Each rank takes a strided slice of every batch; batches whose
            # size is not a multiple of the world size are dropped entirely.
            batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0]
        return torch.utils.data.DataLoader(dataset,
                                           collate_fn=dataset.collater,
                                           batch_sampler=batches,
                                           num_workers=num_workers,
                                           pin_memory=False)

    def build_phone_encoder(self, data_dir):
        """Load ``phone_set.json`` from ``data_dir`` into a ``TokenTextEncoder``.

        Args:
            data_dir: Directory that contains ``phone_set.json``.

        Returns:
            A ``TokenTextEncoder`` built from the loaded phone list, with
            ``replace_oov=','``.
        """
        phone_list_file = os.path.join(data_dir, 'phone_set.json')
        # Context manager so the file handle is closed deterministically.
        with open(phone_list_file) as phone_file:
            phone_list = json.load(phone_file)
        return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')

    def test_start(self):
        """Create the test-time resources.

        Sets up an 8-process pool and an empty list for pending results,
        instantiates the vocoder class selected by ``get_vocoder_cls(hparams)``
        and, when ``hparams['pe_enable']`` is truthy, loads a CUDA
        ``PitchExtractor`` from ``hparams['pe_ckpt']`` in eval mode.
        """
        self.saving_result_pool = Pool(8)
        self.saving_results_futures = []
        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
        # A missing key, None and False all disable the pitch extractor.
        if hparams.get('pe_enable'):
            self.pe = PitchExtractor().cuda()
            utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True)
            self.pe.eval()

    def test_end(self, outputs):
        """Wait for all pending saving jobs, shut the pool down, return ``{}``.

        Args:
            outputs: Unused here; kept for interface compatibility.
        """
        self.saving_result_pool.close()
        # presumably async results queued on the pool by code outside this
        # class; .get() is called on each one in order, with a progress bar.
        for future in tqdm(self.saving_results_futures):
            future.get()
        self.saving_result_pool.join()
        return {}

    def weights_nonzero_speech(self, target):
        """Return a float mask marking positions whose last-dim vector is non-zero.

        A position gets 1.0 when the sum of absolute values along the last
        dimension is non-zero and 0.0 otherwise; the mask is then repeated
        along the last dimension ``target.size(-1)`` times.

        NOTE(review): ``repeat(1, 1, dim)`` suggests ``target`` is expected to
        be 3-D -- confirm against callers.
        """
        dim = target.size(-1)
        return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim)
|
|
|
# Script entry point: only when this module is executed directly, call
# ``TtsTask.start()`` (not defined in the class body in this file; presumably
# inherited from ``BaseTask``).
if __name__ == '__main__':
    TtsTask.start()
|
|
|