# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import logging
import re
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
    NoisyOverlapAugment,
)
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform

logger = logging.getLogger(__name__)


def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
    """
    Zero-pad a list of variable-length tensors into a single batch tensor.

    Args:
        frames (list): list of 2D frames of size L[i]*f_dim, where L[i] is the
            length of the i-th frame and f_dim is the static dimension of
            features. When ``is_audio_input`` is True only the first dimension
            of each tensor is used to size the output.
        is_audio_input (bool): if True, allocate a 2D output of size
            len(frames)*len_max instead of a 3D one.
    Returns:
        3D tensor of size len(frames)*len_max*f_dim (2D of size
        len(frames)*len_max when ``is_audio_input`` is True), where len_max is
        the max of L[i]. Positions past L[i] are zero.
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        # copy each item into the leading slice; the remainder stays zero
        out[i, : v.size(0)] = v
    return out


def _is_int_or_np_int(n):
    """Return True if ``n`` is a Python int or a NumPy scalar holding an int."""
    return isinstance(n, int) or (
        isinstance(n, np.generic) and isinstance(n.item(), int)
    )


@dataclass
class SpeechToTextDatasetItem(object):
    """One example as produced by ``SpeechToTextDataset.__getitem__``."""

    # index of the example within the dataset
    index: int
    # source audio (features or waveform) tensor
    source: torch.Tensor
    # encoded target token ids, if the dataset has target texts
    target: Optional[torch.Tensor] = None
    # speaker id looked up through ``speaker_to_id``, if provided
    speaker_id: Optional[int] = None


class SpeechToTextDataset(FairseqDataset):
    """Dataset pairing audio inputs with (optional) tokenized target texts."""

    # NOTE(review): the template is empty here, yet is_lang_tag() substitutes
    # "{}" in it and check_tgt_lang_tag()/get_lang_tag_idx() call .format() on
    # it; it presumably should contain a "{}" placeholder -- confirm upstream.
    LANG_TAG_TEMPLATE = ""

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        append_eos=True,
    ):
        """
        Store the per-sample lists, build the transforms from ``cfg`` and
        pre-compute target lengths.

        All per-sample lists that are given must have the same length as
        ``audio_paths`` (which must be non-empty), and ``tgt_dict`` and
        ``tgt_texts`` must be provided together or not at all; these
        conditions are enforced with assertions.
        """
        self.split, self.is_train_split = split, is_train_split
        self.cfg = cfg
        self.audio_paths, self.n_frames = audio_paths, n_frames
        self.n_samples = len(audio_paths)
        assert len(n_frames) == self.n_samples > 0
        assert src_texts is None or len(src_texts) == self.n_samples
        assert tgt_texts is None or len(tgt_texts) == self.n_samples
        assert speakers is None or len(speakers) == self.n_samples
        assert src_langs is None or len(src_langs) == self.n_samples
        assert tgt_langs is None or len(tgt_langs) == self.n_samples
        assert ids is None or len(ids) == self.n_samples
        assert (tgt_dict is None and tgt_texts is None) or (
            tgt_dict is not None and tgt_texts is not None
        )
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_langs, self.tgt_langs = src_langs, tgt_langs
        self.speakers = speakers
        self.tgt_dict = tgt_dict
        self.check_tgt_lang_tag()
        self.ids = ids
        # shuffling is only ever enabled for training splits
        self.shuffle = cfg.shuffle if is_train_split else False

        self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.cfg.get_feature_transforms(split, is_train_split)
        )
        self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
            self.cfg.get_waveform_transforms(split, is_train_split)
        )
        # TODO: add these to data_cfg.py
        self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
            self.cfg.get_dataset_transforms(split, is_train_split)
        )

        # check proper usage of transforms
        if self.feature_transforms and self.cfg.use_audio_input:
            logger.warning(
                "Feature transforms will not be applied. To use feature transforms, "
                "set use_audio_input as False in config."
            )

        self.pre_tokenizer = pre_tokenizer
        self.bpe_tokenizer = bpe_tokenizer
        self.n_frames_per_step = n_frames_per_step
        self.speaker_to_id = speaker_to_id

        self.tgt_lens = self.get_tgt_lens_and_check_oov()
        self.append_eos = append_eos

        logger.info(self.__repr__())

    def get_tgt_lens_and_check_oov(self):
        """
        Return the tokenized target length of every sample and log the
        percentage of target tokens that map to the dictionary's unk index.

        Returns a list of zeros when the dataset has no target texts.
        """
        if self.tgt_texts is None:
            return [0 for _ in range(self.n_samples)]
        tgt_lens = []
        n_tokens, n_oov_tokens = 0, 0
        for i in range(self.n_samples):
            tokenized = self.get_tokenized_tgt_text(i).split(" ")
            oov_tokens = [
                t
                for t in tokenized
                if self.tgt_dict.index(t) == self.tgt_dict.unk_index
            ]
            n_tokens += len(tokenized)
            n_oov_tokens += len(oov_tokens)
            tgt_lens.append(len(tokenized))
        logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
        return tgt_lens

    def __repr__(self):
        return (
            self.__class__.__name__
            + f'(split="{self.split}", n_samples={self.n_samples:_}, '
            f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
            f"n_frames_per_step={self.n_frames_per_step}, "
            f"shuffle={self.shuffle}, "
            f"feature_transforms={self.feature_transforms}, "
            f"waveform_transforms={self.waveform_transforms}, "
            f"dataset_transforms={self.dataset_transforms})"
        )

    @classmethod
    def is_lang_tag(cls, token):
        """Match ``token`` against the regex derived from LANG_TAG_TEMPLATE."""
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    def check_tgt_lang_tag(self):
        """
        When ``cfg.prepend_tgt_lang_tag`` is set, assert that target languages
        and a target dictionary exist and that every language tag is in the
        dictionary.
        """
        if self.cfg.prepend_tgt_lang_tag:
            assert self.tgt_langs is not None and self.tgt_dict is not None
            tgt_lang_tags = [
                self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
            ]
            assert all(t in self.tgt_dict for t in tgt_lang_tags)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        """Encode ``text`` with ``tokenizer``; return it unchanged if None."""
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
        """
        Return the target text for ``index`` after pre-tokenization and BPE.

        For a list of indices the target texts are joined with a single space
        before tokenization.
        """
        if _is_int_or_np_int(index):
            text = self.tgt_texts[index]
        else:
            text = " ".join([self.tgt_texts[i] for i in index])

        text = self.tokenize(self.pre_tokenizer, text)
        text = self.tokenize(self.bpe_tokenizer, text)
        return text

    def pack_frames(self, feature: torch.Tensor):
        """
        Group every ``n_frames_per_step`` consecutive frames into one row.

        Trailing frames that do not fill a complete group are dropped. The
        input is returned untouched when ``n_frames_per_step`` is 1.
        """
        if self.n_frames_per_step == 1:
            return feature
        n_packed_frames = feature.shape[0] // self.n_frames_per_step
        feature = feature[: self.n_frames_per_step * n_packed_frames]
        return feature.reshape(n_packed_frames, -1)

    @classmethod
    def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
        """
        Return the dictionary index of the language tag for ``lang``.

        Asserts that the tag does not resolve to the unknown symbol.
        """
        lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
        assert lang_tag_idx != dictionary.unk()
        return lang_tag_idx

    def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
        """
        Gives source audio for given index with any relevant transforms
        applied. For ConcatAug, source audios for given indices are
        concatenated in given order.
        Args:
            index (int or List[int]): index—or in the case of ConcatAug,
            indices—to pull the source audio for
        Returns:
            source audios concatenated for given indices with
            relevant transforms applied
        """
        if _is_int_or_np_int(index):
            source = get_features_or_waveform(
                self.audio_paths[index],
                need_waveform=self.cfg.use_audio_input,
                use_sample_rate=self.cfg.use_sample_rate,
                waveform_transforms=self.waveform_transforms,
            )
        else:
            source = np.concatenate(
                [
                    get_features_or_waveform(
                        self.audio_paths[i],
                        need_waveform=self.cfg.use_audio_input,
                        use_sample_rate=self.cfg.use_sample_rate,
                        waveform_transforms=self.waveform_transforms,
                    )
                    for i in index
                ]
            )
        if self.cfg.use_audio_input:
            source = torch.from_numpy(source).float()
            if self.cfg.standardize_audio:
                with torch.no_grad():
                    # normalize over the whole tensor
                    source = F.layer_norm(source, source.shape)
        else:
            # feature transforms are only applied on the non-waveform path
            if self.feature_transforms is not None:
                source = self.feature_transforms(source)
            source = torch.from_numpy(source).float()
        return source

    def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
        """
        Load, transform and encode the example at ``index``.

        With a ConcatAugment dataset transform, source audio and target text
        come from the indices chosen by the transform, while language tags and
        the speaker id are still taken from ``index``.
        """
        has_concat = self.dataset_transforms.has_transform(ConcatAugment)
        if has_concat:
            concat = self.dataset_transforms.get_transform(ConcatAugment)
            indices = concat.find_indices(index, self.n_frames, self.n_samples)

        source = self._get_source_audio(indices if has_concat else index)
        source = self.pack_frames(source)

        target = None
        if self.tgt_texts is not None:
            tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=self.append_eos
            ).long()
            if self.cfg.prepend_tgt_lang_tag:
                lang_tag_idx = self.get_lang_tag_idx(
                    self.tgt_langs[index], self.tgt_dict
                )
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)

        if self.cfg.prepend_bos_and_append_tgt_lang_tag:
            bos = torch.LongTensor([self.tgt_dict.bos()])
            lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
            assert lang_tag_idx != self.tgt_dict.unk()
            lang_tag_idx = torch.LongTensor([lang_tag_idx])
            target = torch.cat((bos, target, lang_tag_idx), 0)

        speaker_id = None
        if self.speaker_to_id is not None:
            speaker_id = self.speaker_to_id[self.speakers[index]]
        return SpeechToTextDatasetItem(
            index=index, source=source, target=target, speaker_id=speaker_id
        )

    def __len__(self):
        return self.n_samples

    def collater(
        self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
    ) -> Dict:
        """
        Merge a list of items into a mini-batch sorted by descending number of
        source frames.

        Args:
            samples: items produced by ``__getitem__``.
            return_order: if True, also return the sort permutation under the
                ``"order"`` key.
        Returns:
            An empty dict for an empty ``samples`` list; otherwise a dict with
            keys ``id``, ``net_input`` (``src_tokens``, ``src_lengths``,
            ``prev_output_tokens``), ``speaker``, ``target``,
            ``target_lengths``, ``ntokens`` and ``nsentences``.
        """
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)

        sources = [x.source for x in samples]
        has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
        if has_NOAug and self.cfg.use_audio_input:
            NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
            sources = NOAug(sources)

        frames = _collate_frames(sources, self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [x.target.size(0) for x in samples], dtype=torch.long
            ).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                eos_idx=None,
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(x.target.size(0) for x in samples)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
        }
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": speaker,
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out

    def num_tokens(self, index):
        """Return the number of source frames of the example at ``index``."""
        return self.n_frames[index]

    def size(self, index):
        """Return ``(n_frames, tgt_len)`` for the example at ``index``."""
        return self.n_frames[index], self.tgt_lens[index]

    @property
    def sizes(self):
        """Source frame counts of all examples as a NumPy array."""
        return np.array(self.n_frames)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True

    def ordered_indices(self):
        """
        Return example indices sorted by descending number of frames, with
        ties broken by a random permutation (when shuffling) or by the
        original order.
        """
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # first by descending order of # of frames then by original/random order
        order.append([-n for n in self.n_frames])
        return np.lexsort(order)

    def prefetch(self, indices):
        """
        Prefetching is not implemented for this dataset.

        Raises:
            NotImplementedError: always.
        """
        # The previous `raise False` is not a valid raise target and surfaced
        # as an unrelated TypeError; raise the intended exception explicitly.
        raise NotImplementedError


class TextTargetMultitaskData(object):
    """Text targets of an auxiliary task, keyed by sample id."""

    # mandatory columns
    KEY_ID, KEY_TEXT = "id", "tgt_text"
    # NOTE(review): the template is empty here, yet is_lang_tag() substitutes
    # "{}" in it and get_lang_tag_idx() calls .format() on it; it presumably
    # should contain a "{}" placeholder -- confirm upstream.
    LANG_TAG_TEMPLATE = ""

    def __init__(self, args, split, tgt_dict):
        """
        Load the ``split`` manifest under ``args.data`` into an id -> text
        mapping and build the tokenizers described by ``args.config``.
        """
        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
        self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
        self.dict = tgt_dict
        # no EOS is appended for a "ctc" decoder
        self.append_eos = args.decoder_type != "ctc"
        self.pre_tokenizer = self.build_tokenizer(args)
        self.bpe_tokenizer = self.build_bpe(args)
        self.prepend_bos_and_append_tgt_lang_tag = (
            args.prepend_bos_and_append_tgt_lang_tag
        )
        self.eos_token = args.eos_token
        # used with .get() in get_lang_tag_idx() -- assumes this attribute is a
        # mapping (e.g. exposed as a property on ``args``); TODO confirm
        self.lang_tag_mapping = args.get_lang_tag_mapping

    @classmethod
    def is_lang_tag(cls, token):
        """Match ``token`` against the regex derived from LANG_TAG_TEMPLATE."""
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        """Encode ``text`` with ``tokenizer``; return it unchanged if None."""
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: int):
        """Return the text stored under ``index`` after both tokenizers."""
        text = self.tokenize(self.pre_tokenizer, self.data[index])
        text = self.tokenize(self.bpe_tokenizer, text)
        return text

    def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
        """
        Return the dictionary index of the (possibly remapped) language tag.

        The tag built from LANG_TAG_TEMPLATE is replaced by its entry in
        ``lang_tag_mapping`` when one exists. Asserts that the final tag does
        not resolve to the unknown symbol.
        """
        lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
        lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
        lang_tag_idx = dictionary.index(lang_tag)
        assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
        return lang_tag_idx

    def build_tokenizer(self, args):
        """Build the pre-tokenizer from ``args.config``, or None if unset."""
        pre_tokenizer = args.config.get("pre_tokenizer")
        if pre_tokenizer is not None:
            logger.info(f"pre-tokenizer: {pre_tokenizer}")
            return encoders.build_tokenizer(Namespace(**pre_tokenizer))
        else:
            return None

    def build_bpe(self, args):
        """Build the BPE tokenizer from ``args.config``, or None if unset."""
        bpe_tokenizer = args.config.get("bpe_tokenizer")
        if bpe_tokenizer is not None:
            logger.info(f"tokenizer: {bpe_tokenizer}")
            return encoders.build_bpe(Namespace(**bpe_tokenizer))
        else:
            return None

    def get(self, sample_id, tgt_lang=None):
        """
        Return the encoded target for ``sample_id``.

        When ``prepend_bos_and_append_tgt_lang_tag`` is set, BOS is prepended
        and the tag index of ``tgt_lang`` is appended. If ``sample_id`` is
        unknown, a warning is logged and an empty tensor is returned.
        """
        if sample_id in self.data:
            tokenized = self.get_tokenized_tgt_text(sample_id)
            target = self.dict.encode_line(
                tokenized,
                add_if_not_exist=False,
                append_eos=self.append_eos,
            )
            if self.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
                assert lang_tag_idx != self.dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)
            return target
        else:
            logger.warning(f"no target for {sample_id}")
            return torch.IntTensor([])

    def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
        """
        Pad a list of target tensors into a batch.

        Returns a dict with ``prev_output_tokens``, ``target``,
        ``target_lengths`` and ``ntokens``.
        """
        out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=False,
        ).long()

        prev_out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=True,
        ).long()

        target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
        ntokens = sum(t.size(0) for t in samples)

        output = {
            "prev_output_tokens": prev_out,
            "target": out,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
        }

        return output


class SpeechToTextMultitaskDataset(SpeechToTextDataset):
    """SpeechToTextDataset that also yields targets for auxiliary tasks."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # task name -> task data object, filled by add_multitask_dataset()
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        """Register ``task_data`` as the target source for ``task_name``."""
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
        """
        Return the base item together with a dict mapping each registered
        task name to its target for this sample.
        """
        s2t_data = super().__getitem__(index)

        multitask_target = {}
        sample_id = self.ids[index]
        tgt_lang = self.tgt_langs[index]
        for task_name, task_dataset in self.multitask_data.items():
            multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)

        return s2t_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        """
        Collate base items with the parent collater, then add the collated
        targets of every registered task under ``out["multitask"]``, reordered
        with the same permutation as the main batch.
        """
        if len(samples) == 0:
            return {}

        out = super().collater([s for s, _ in samples], return_order=True)
        order = out["order"]
        # "order" is only needed here to align the auxiliary targets
        del out["order"]

        for task_name, task_dataset in self.multitask_data.items():
            if "multitask" not in out:
                out["multitask"] = {}
            d = [s[task_name] for _, s in samples]
            task_target = task_dataset.collater(d)
            out["multitask"][task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(0, order),
                "ntokens": task_target["ntokens"],
            }
            out["multitask"][task_name]["net_input"] = {
                "prev_output_tokens": task_target["prev_output_tokens"].index_select(
                    0, order
                ),
            }

        return out


class SpeechToTextDatasetCreator(object):
    """Factory building speech-to-text datasets from TSV manifests."""

    # mandatory columns
    KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
    KEY_TGT_TEXT = "tgt_text"
    # optional columns
    KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        """
        Build a dataset from manifest rows.

        A SpeechToTextMultitaskDataset is created (and one
        TextTargetMultitaskData registered per task) when ``multitask`` is a
        non-empty dict; otherwise a plain SpeechToTextDataset is returned.
        """
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        has_multitask = multitask is not None and len(multitask.keys()) > 0
        dataset_cls = (
            SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
        )

        ds = dataset_cls(
            split=split_name,
            is_train_split=is_train_split,
            cfg=cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )

        if has_multitask:
            for task_name, task_obj in multitask.items():
                task_data = TextTargetMultitaskData(
                    task_obj.args, split_name, task_obj.target_dictionary
                )
                ds.add_multitask_dataset(task_name, task_data)
        return ds

    @classmethod
    def get_size_ratios(
        cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
    ) -> List[float]:
        """Size ratios for temperature-based sampling
        (https://arxiv.org/abs/1907.05019)

        Each dataset must contain exactly one source->target language pair
        (asserted). Sizes are measured as the total number of frames per
        language pair; one ratio is returned per input dataset.
        """

        id_to_lp, lp_to_sz = {}, defaultdict(int)
        for ds in datasets:
            lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
            assert len(lang_pairs) == 1
            lang_pair = list(lang_pairs)[0]
            id_to_lp[ds.split] = lang_pair
            lp_to_sz[lang_pair] += sum(ds.n_frames)

        sz_sum = sum(v for v in lp_to_sz.values())
        lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
        # temperature scaling followed by renormalization
        lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
        prob_sum = sum(v for v in lp_to_tgt_prob.values())
        lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
        lp_to_sz_ratio = {
            k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
        }
        size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]

        p_formatted = {
            k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
        }
        logger.info(f"sampling probability balancing: {p_formatted}")
        sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
        logger.info(f"balanced sampling size ratio: {sr_formatted}")
        return size_ratio

    @classmethod
    def _load_samples_from_tsv(cls, root: str, split: str):
        """
        Read ``<root>/<split>.tsv`` into a list of row dicts.

        Raises:
            FileNotFoundError: if the manifest file does not exist.
            ValueError: if the manifest contains no rows.
        """
        tsv_path = Path(root) / f"{split}.tsv"
        if not tsv_path.is_file():
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            # quoting is disabled so quote characters in text are kept verbatim
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            samples = [dict(e) for e in reader]
        if len(samples) == 0:
            raise ValueError(f"Empty manifest: {tsv_path}")
        return samples

    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        """Load one split's manifest and build its dataset via _from_list."""
        samples = cls._load_samples_from_tsv(root, split)
        return cls._from_list(
            split,
            is_train_split,
            samples,
            cfg,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            multitask,
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        splits: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        n_frames_per_step: int = 1,
        speaker_to_id=None,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        """
        Build the dataset for a comma-separated list of splits.

        For a training split with more than one dataset and
        ``cfg.sampling_alpha != 1.0``, every dataset is wrapped in a
        ResamplingDataset using temperature-based size ratios. Multiple
        datasets are returned as a ConcatDataset, a single one as is.
        """
        datasets = [
            cls._from_tsv(
                root=root,
                cfg=cfg,
                split=split,
                tgt_dict=tgt_dict,
                is_train_split=is_train_split,
                pre_tokenizer=pre_tokenizer,
                bpe_tokenizer=bpe_tokenizer,
                n_frames_per_step=n_frames_per_step,
                speaker_to_id=speaker_to_id,
                multitask=multitask,
            )
            for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(
                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
                )
                for r, d in zip(size_ratios, datasets)
            ]

        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]