| import matplotlib
|
| matplotlib.use('Agg')
|
|
|
| import torch
|
| import numpy as np
|
| import os
|
|
|
| from tasks.base_task import BaseDataset
|
| from tasks.tts.fs2 import FastSpeech2Task
|
| from modules.fastspeech.pe import PitchExtractor
|
| import utils
|
| from utils.indexed_datasets import IndexedDataset
|
| from utils.hparams import hparams
|
| from utils.plot import f0_to_figure
|
| from utils.pitch_utils import norm_interp_f0, denorm_f0
|
|
|
|
|
class PeDataset(BaseDataset):
    """Dataset over binarized (mel, f0, pitch) items for pitch-extractor training.

    Reads items lazily from an :class:`IndexedDataset` under
    ``hparams['binary_data_dir']``; also publishes the training split's global
    f0 mean/std into ``hparams`` (as plain floats) when the stats file exists.
    """

    def __init__(self, prefix, shuffle=False):
        super().__init__(shuffle)
        self.data_dir = hparams['binary_data_dir']
        self.prefix = prefix
        self.hparams = hparams
        self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy')
        self.indexed_ds = None  # opened lazily in _get_item

        # Global f0 statistics from the training split; mirrored into hparams
        # so downstream modules can (de)normalize f0 consistently.
        f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy'
        if os.path.exists(f0_stats_fn):
            f0_stats = np.load(f0_stats_fn)
            self.f0_mean, self.f0_std = f0_stats
            hparams['f0_mean'] = float(f0_stats[0])
            hparams['f0_std'] = float(f0_stats[1])
        else:
            self.f0_mean = self.f0_std = None
            hparams['f0_mean'] = hparams['f0_std'] = None

        # For the test split, optionally restrict to a fixed set of items:
        # the first `num_test_samples` plus any explicitly listed `test_ids`.
        if prefix == 'test' and hparams['num_test_samples'] > 0:
            self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids']
            self.sizes = [self.sizes[i] for i in self.avail_idxs]

    def _get_item(self, index):
        """Fetch the raw binarized item, mapping through avail_idxs if set."""
        if getattr(self, 'avail_idxs', None) is not None:
            index = self.avail_idxs[index]
        if self.indexed_ds is None:
            # Open once per process/worker on first access.
            self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}')
        return self.indexed_ds[index]

    def __getitem__(self, index):
        """Return one sample dict, truncated to hparams['max_frames'] frames."""
        hparams = self.hparams
        item = self._get_item(index)
        max_frames = hparams['max_frames']
        mel = torch.Tensor(item['mel'])[:max_frames]
        f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams)
        pitch = torch.LongTensor(item.get("pitch"))[:max_frames]
        return {
            "id": index,
            "item_name": item['item_name'],
            "text": item['txt'],
            "mel": mel,
            "pitch": pitch,
            "f0": f0,
            "uv": uv,
        }

    def collater(self, samples):
        """Pad a list of samples into a single batch dict."""
        if not samples:
            return {}
        ids = torch.LongTensor([s['id'] for s in samples])
        item_names = [s['item_name'] for s in samples]
        texts = [s['text'] for s in samples]
        f0 = utils.collate_1d([s['f0'] for s in samples], 0.0)
        pitch = utils.collate_1d([s['pitch'] for s in samples])
        uv = utils.collate_1d([s['uv'] for s in samples])
        mels = utils.collate_2d([s['mel'] for s in samples], 0.0)
        mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples])
        return {
            'id': ids,
            'item_name': item_names,
            'nsamples': len(samples),
            'text': texts,
            'mels': mels,
            'mel_lengths': mel_lengths,
            'pitch': pitch,
            'f0': f0,
            'uv': uv,
        }
|
|
|
|
|
class PitchExtractionTask(FastSpeech2Task):
    """Task that trains a standalone PitchExtractor to predict f0 from mels.

    Reuses the FastSpeech2 training loop; batches come from :class:`PeDataset`.
    """

    def __init__(self):
        super().__init__()
        self.dataset_cls = PeDataset

    def build_tts_model(self):
        """Instantiate the pitch-extractor model from hparams."""
        self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers'])

    def _training_step(self, sample, batch_idx, _):
        """Run one training step.

        Returns:
            (total_loss, loss_output): scalar loss to backprop, plus a dict of
            per-loss values and batch size for logging.
        """
        loss_output = self.run_model(self.model, sample)
        # Only tensor losses that require grad contribute to the total.
        total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad])
        loss_output['batch_size'] = sample['mels'].size()[0]
        return total_loss, loss_output

    def validation_step(self, sample, batch_idx):
        """Compute validation losses and optionally plot predicted f0 curves."""
        outputs = {}
        # (Removed a redundant `outputs['losses'] = {}` that was immediately
        # overwritten by the assignment below.)
        outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=True)
        outputs['total_loss'] = sum(outputs['losses'].values())
        outputs['nsamples'] = sample['nsamples']
        outputs = utils.tensors_to_scalars(outputs)
        if batch_idx < hparams['num_valid_plots']:
            self.plot_pitch(batch_idx, model_out, sample)
        return outputs

    def run_model(self, model, sample, return_output=False, infer=False):
        """Forward the mels through the model and collect pitch losses.

        NOTE: `infer` is kept for signature compatibility with the base task;
        the pitch extractor runs the same forward pass either way.
        """
        output = model(sample['mels'])
        losses = {}
        self.add_pitch_loss(output, sample, losses)
        if not return_output:
            return losses
        return losses, output

    def plot_pitch(self, batch_idx, model_out, sample):
        """Log a figure comparing ground-truth f0 against the model's prediction."""
        gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams)
        self.logger.experiment.add_figure(
            f'f0_{batch_idx}',
            f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0]),
            self.global_step)

    def add_pitch_loss(self, output, sample, losses):
        """Add f0 prediction losses to `losses`, masking out padded frames."""
        mel = sample['mels']
        f0 = sample['f0']
        uv = sample['uv']
        # Frames whose mel is entirely zero are padding; exclude them.
        nonpadding = (mel.abs().sum(-1) > 0).float()
        self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding)