# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import numpy as np
import torch

from fairseq.data import Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    _collate_frames,
)


@dataclass
class TextToSpeechDatasetItem:
    """A single example produced by ``TextToSpeechDataset.__getitem__``.

    ``source`` and ``target`` are copied unchanged from the parent
    ``SpeechToTextDataset`` item: ``source`` holds the audio frames and
    ``target`` the text token ids. The optional per-token prosody tensors
    (``duration``, ``pitch``, ``energy``) carry one extra trailing 0 for EOS.
    """

    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None
    duration: Optional[torch.Tensor] = None
    pitch: Optional[torch.Tensor] = None
    energy: Optional[torch.Tensor] = None


class TextToSpeechDataset(SpeechToTextDataset):
    """Text-to-speech view of ``SpeechToTextDataset``.

    The parent's roles are swapped at collation time. The text ``target`` of
    each parent item becomes the model input (``src_tokens``). The audio
    ``source`` becomes the prediction ``target``.

    Optional per-sample prosody data, indexed like the other per-sample
    lists, can also be supplied:

    * ``durations``: lists of ints, one list per sample.
    * ``pitches`` / ``energies``: paths loaded with
      ``get_features_or_waveform``.
    """

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        durations: Optional[List[List[int]]] = None,
        pitches: Optional[List[str]] = None,
        energies: Optional[List[str]] = None,
    ):
        super().__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )
        # Each of these is either None (feature unavailable) or a list with
        # one entry per sample.
        self.durations = durations
        self.pitches = pitches
        self.energies = energies

    @staticmethod
    def _load_with_eos_pad(path: str) -> torch.Tensor:
        """Load a 1-D feature array from ``path`` and append a 0 for EOS.

        Returns a float tensor.
        """
        values = get_features_or_waveform(path)
        return torch.from_numpy(np.concatenate((values, [0]))).float()

    def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
        """Return the parent item at ``index``, extended with prosody data.

        Each optional prosody feature is ``None`` when the dataset was built
        without it. Otherwise it is a tensor with a trailing 0 for the EOS
        token appended by the parent's text target.
        """
        s2t_item = super().__getitem__(index)

        duration, pitch, energy = None, None, None
        if self.durations is not None:
            # Pad 0 for EOS. ``+`` builds a new list, so the stored per-sample
            # durations are never mutated.
            duration = torch.tensor(self.durations[index] + [0], dtype=torch.long)
        if self.pitches is not None:
            pitch = self._load_with_eos_pad(self.pitches[index])
        if self.energies is not None:
            energy = self._load_with_eos_pad(self.energies[index])

        return TextToSpeechDatasetItem(
            index=index,
            source=s2t_item.source,
            target=s2t_item.target,
            speaker_id=s2t_item.speaker_id,
            duration=duration,
            pitch=pitch,
            energy=energy,
        )

    @staticmethod
    def _check_token_aligned(
        name: str, src_tokens: torch.Tensor, values: torch.Tensor
    ) -> None:
        """Assert that a per-token batch tensor matches ``src_tokens`` in length.

        The comparison is on dim 1, the padded token length.
        """
        assert src_tokens.shape[1] == values.shape[1], (
            f"{name} length {values.shape[1]} does not match "
            f"src_tokens length {src_tokens.shape[1]}"
        )

    def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
        """Collate items into a mini-batch, longest text first.

        Args:
            samples: items produced by ``__getitem__``.

        Returns:
            ``{}`` when ``samples`` is empty. Otherwise a batch dict with:

            * ``net_input``: ``src_tokens``, ``src_lengths`` and
              ``prev_output_tokens``.
            * ``target`` and ``target_lengths``: the padded acoustic frames
              and their lengths.
            * ``speaker``, ``durations``, ``pitches``, ``energies``: optional
              tensors, ``None`` when unavailable.
            * ``id``, ``ntokens`` (total target frames), ``nsentences`` and
              ``src_texts``.

            Every batched tensor is reordered by the same length sort.
        """
        if not samples:
            return {}

        # The parent item's text ``target`` is the TTS model input, so the
        # batch is ordered by text length.
        src_lengths, order = torch.tensor(
            [s.target.shape[0] for s in samples], dtype=torch.long
        ).sort(descending=True)
        id_ = torch.tensor([s.index for s in samples], dtype=torch.long).index_select(
            0, order
        )

        # The parent item's audio ``source`` is what the TTS model predicts.
        feat = _collate_frames(
            [s.source for s in samples], self.cfg.use_audio_input
        ).index_select(0, order)
        target_lengths = torch.tensor(
            [s.source.shape[0] for s in samples], dtype=torch.long
        ).index_select(0, order)

        src_tokens = fairseq_data_utils.collate_tokens(
            [s.target for s in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).index_select(0, order)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        # Decoder input: the target frames shifted right by one step, with an
        # all-zero frame prepended and the last frame dropped.
        bsz, _, d = feat.size()
        prev_output_tokens = torch.cat(
            (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
        )

        durations, pitches, energies = None, None, None
        if self.durations is not None:
            durations = fairseq_data_utils.collate_tokens(
                [s.duration for s in samples], 0
            ).index_select(0, order)
            self._check_token_aligned("durations", src_tokens, durations)
        if self.pitches is not None:
            pitches = _collate_frames([s.pitch for s in samples], True)
            pitches = pitches.index_select(0, order)
            self._check_token_aligned("pitches", src_tokens, pitches)
        if self.energies is not None:
            energies = _collate_frames([s.energy for s in samples], True)
            energies = energies.index_select(0, order)
            self._check_token_aligned("energies", src_tokens, energies)

        src_texts = [self.tgt_dict.string(samples[i].target) for i in order]

        return {
            "id": id_,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": prev_output_tokens,
            },
            "speaker": speaker,
            "target": feat,
            "durations": durations,
            "pitches": pitches,
            "energies": energies,
            "target_lengths": target_lengths,
            # Tensor.sum() replaces Python's element-by-element sum();
            # .item() yields the same int.
            "ntokens": target_lengths.sum().item(),
            "nsentences": len(samples),
            "src_texts": src_texts,
        }


class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
    """Build ``TextToSpeechDataset`` instances from manifest rows.

    Besides the speech-to-text columns, rows may carry three optional
    columns:

    * ``duration``: whitespace-separated integers.
    * ``pitch`` / ``energy``: feature file paths relative to
      ``cfg.audio_root``.

    An optional column is used only when every row in the list provides it;
    otherwise the dataset receives ``None`` for it.
    """

    KEY_DURATION = "duration"
    KEY_PITCH = "pitch"
    KEY_ENERGY = "energy"

    @classmethod
    def _parse_optional_column(
        cls,
        samples: List[Dict],
        key: str,
        convert: Callable[[Any], Any],
    ) -> Optional[List[Any]]:
        """Return ``convert`` applied to each row's ``key`` value.

        Returns ``None`` if any row is missing ``key`` or stores ``None`` for
        it. In that case nothing is converted.
        """
        raw = [s.get(key, None) for s in samples]
        if any(value is None for value in raw):
            return None
        return [convert(value) for value in raw]

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask=None,
    ) -> TextToSpeechDataset:
        """Create a ``TextToSpeechDataset`` from a list of manifest rows.

        ``multitask`` is accepted but not used by this implementation.
        """
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        # str.split() (rather than split(" ")) tolerates repeated, leading or
        # trailing whitespace, which would otherwise produce int("") errors.
        durations = cls._parse_optional_column(
            samples, cls.KEY_DURATION, lambda dd: [int(d) for d in dd.split()]
        )
        pitches = cls._parse_optional_column(
            samples, cls.KEY_PITCH, lambda pp: (audio_root / pp).as_posix()
        )
        energies = cls._parse_optional_column(
            samples, cls.KEY_ENERGY, lambda ee: (audio_root / ee).as_posix()
        )

        return TextToSpeechDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
            durations=durations,
            pitches=pitches,
            energies=energies,
        )