# Copyright (c) Facebook, Inc. and its affiliates. # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import logging from pathlib import Path from typing import Dict, List, NamedTuple, Optional import torch from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset from fairseq.data import data_utils as fairseq_data_utils from fairseq.data.audio.speech_to_text_dataset import ( S2TDataConfig, SpeechToTextDataset, SpeechToTextDatasetCreator, ) logger = logging.getLogger(__name__) class S2TJointDataConfig(S2TDataConfig): """Wrapper class for data config YAML""" @property def src_vocab_filename(self): """fairseq vocabulary file under data root""" return self.config.get("src_vocab_filename", "src_dict.txt") @property def src_pre_tokenizer(self) -> Dict: """Pre-tokenizer to apply before subword tokenization. Returning a dictionary with `tokenizer` providing the tokenizer name and the other items providing the tokenizer-specific arguments. Tokenizers are defined in `fairseq.data.encoders.*`""" return self.config.get("src_pre_tokenizer", {"tokenizer": None}) @property def src_bpe_tokenizer(self) -> Dict: """Subword tokenizer to apply on source text after pre-tokenization. Returning a dictionary with `bpe` providing the tokenizer name and the other items providing the tokenizer-specific arguments. Tokenizers are defined in `fairseq.data.encoders.*`""" return self.config.get("src_bpe_tokenizer", {"bpe": None}) @property def prepend_tgt_lang_tag_no_change(self) -> bool: """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for to-many multilingual setting). No change needed during inference. This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos. """ value = self.config.get("prepend_tgt_lang_tag_no_change", None) if value is None: return self.config.get("prepend_tgt_lang_tag_as_bos", False) return value @property def sampling_text_alpha(self): """Hyper-parameter alpha = 1/T for temperature-based resampling. (text input only) (alpha = 1 for no resampling)""" return self.config.get("sampling_text_alpha", 1.0) class SpeechToTextJointDatasetItem(NamedTuple): index: int source: torch.Tensor target: Optional[torch.Tensor] = None src_txt_tokens: Optional[torch.Tensor] = None tgt_lang_tag: Optional[int] = None src_lang_tag: Optional[int] = None tgt_alignment: Optional[torch.Tensor] = None # use_src_lang_id: # 0: don't use src_lang_id # 1: attach src_lang_id to the src_txt_tokens as eos class SpeechToTextJointDataset(SpeechToTextDataset): def __init__( self, split: str, is_train_split: bool, cfg: S2TJointDataConfig, audio_paths: List[str], n_frames: List[int], src_texts: Optional[List[str]] = None, tgt_texts: Optional[List[str]] = None, speakers: Optional[List[str]] = None, src_langs: Optional[List[str]] = None, tgt_langs: Optional[List[str]] = None, ids: Optional[List[str]] = None, tgt_dict: Optional[Dictionary] = None, src_dict: Optional[Dictionary] = None, pre_tokenizer=None, bpe_tokenizer=None, src_pre_tokenizer=None, src_bpe_tokenizer=None, append_eos: Optional[bool] = True, alignment: Optional[List[str]] = None, use_src_lang_id: Optional[int] = 0, ): super().__init__( split, is_train_split, cfg, audio_paths, n_frames, src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers, src_langs=src_langs, tgt_langs=tgt_langs, ids=ids, tgt_dict=tgt_dict, pre_tokenizer=pre_tokenizer, bpe_tokenizer=bpe_tokenizer, append_eos=append_eos, ) self.src_dict = src_dict self.src_pre_tokenizer = src_pre_tokenizer self.src_bpe_tokenizer = src_bpe_tokenizer self.alignment = None self.use_src_lang_id = use_src_lang_id if alignment is not None: self.alignment = [ [float(s) for s in sample.split()] for sample in alignment ] def get_tokenized_src_text(self, index: int): text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index]) text = self.tokenize(self.src_bpe_tokenizer, text) return text def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem: s2t_dataset_item = super().__getitem__(index) src_tokens = None src_lang_tag = None if self.src_texts is not None and self.src_dict is not None: src_tokens = self.get_tokenized_src_text(index) src_tokens = self.src_dict.encode_line( src_tokens, add_if_not_exist=False, append_eos=True ).long() if self.use_src_lang_id > 0: src_lang_tag = self.get_lang_tag_idx( self.src_langs[index], self.src_dict ) tgt_lang_tag = None if self.cfg.prepend_tgt_lang_tag_no_change: # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict) ali = None if self.alignment is not None: ali = torch.Tensor(self.alignment[index]).float() return SpeechToTextJointDatasetItem( index=index, source=s2t_dataset_item.source, target=s2t_dataset_item.target, src_txt_tokens=src_tokens, tgt_lang_tag=tgt_lang_tag, src_lang_tag=src_lang_tag, tgt_alignment=ali, ) def __len__(self): return self.n_samples def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict: s2t_out = super().collater(samples, return_order=True) if s2t_out == {}: return s2t_out net_input, order = s2t_out["net_input"], s2t_out["order"] if self.src_texts is not None and self.src_dict is not None: src_txt_tokens = fairseq_data_utils.collate_tokens( [x.src_txt_tokens for x in samples], self.src_dict.pad(), self.src_dict.eos(), left_pad=False, move_eos_to_beginning=False, ) src_txt_lengths = torch.tensor( [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long ) if self.use_src_lang_id > 0: src_lang_idxs = torch.tensor( [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype ) if self.use_src_lang_id == 1: # replace eos with lang_id eos_idx = src_txt_lengths - 1 src_txt_tokens.scatter_( 1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1) ) else: raise NotImplementedError("Implementation is required") src_txt_tokens = src_txt_tokens.index_select(0, order) src_txt_lengths = src_txt_lengths.index_select(0, order) net_input["src_txt_tokens"] = src_txt_tokens net_input["src_txt_lengths"] = src_txt_lengths net_input["alignment"] = None if self.alignment is not None: max_len = max([s.tgt_alignment.size(0) for s in samples]) alignment = torch.ones(len(samples), max_len).float() for i, s in enumerate(samples): cur_len = s.tgt_alignment.size(0) alignment[i][:cur_len].copy_(s.tgt_alignment) net_input["alignment"] = alignment.index_select(0, order) if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None: for i in range(len(samples)): net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag out = { "id": s2t_out["id"], "net_input": net_input, "target": s2t_out["target"], "target_lengths": s2t_out["target_lengths"], "ntokens": s2t_out["ntokens"], "nsentences": len(samples), } return out class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator): KEY_ALIGN = "align" @classmethod def _from_list( cls, split_name: str, is_train_split, samples: List[Dict], cfg: S2TJointDataConfig, tgt_dict, src_dict, pre_tokenizer, bpe_tokenizer, src_pre_tokenizer, src_bpe_tokenizer, append_eos, use_src_lang_id, ) -> SpeechToTextJointDataset: audio_root = Path(cfg.audio_root) ids = [s[cls.KEY_ID] for s in samples] audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples] n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples] tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples] src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples] speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples] src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples] tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples] tgt_alignment = None if cls.KEY_ALIGN in samples[0].keys(): tgt_alignment = [s[cls.KEY_ALIGN] for s in samples] return SpeechToTextJointDataset( split_name, is_train_split, cfg, audio_paths, n_frames, src_texts=src_texts, tgt_texts=tgt_texts, speakers=speakers, src_langs=src_langs, tgt_langs=tgt_langs, ids=ids, tgt_dict=tgt_dict, src_dict=src_dict, pre_tokenizer=pre_tokenizer, bpe_tokenizer=bpe_tokenizer, src_pre_tokenizer=src_pre_tokenizer, src_bpe_tokenizer=src_bpe_tokenizer, append_eos=append_eos, alignment=tgt_alignment, use_src_lang_id=use_src_lang_id, ) @classmethod def _from_tsv( cls, root: str, cfg: S2TJointDataConfig, split: str, tgt_dict, src_dict, is_train_split: bool, pre_tokenizer, bpe_tokenizer, src_pre_tokenizer, src_bpe_tokenizer, append_eos: bool, use_src_lang_id: int, ) -> SpeechToTextJointDataset: samples = cls._load_samples_from_tsv(root, split) return cls._from_list( split, is_train_split, samples, cfg, tgt_dict, src_dict, pre_tokenizer, bpe_tokenizer, src_pre_tokenizer, src_bpe_tokenizer, append_eos, use_src_lang_id, ) @classmethod def from_tsv( cls, root: str, cfg: S2TJointDataConfig, splits: str, tgt_dict, src_dict, pre_tokenizer, bpe_tokenizer, src_pre_tokenizer, src_bpe_tokenizer, is_train_split: bool, epoch: int, seed: int, append_eos: Optional[bool] = True, use_src_lang_id: Optional[int] = 0, ) -> SpeechToTextJointDataset: datasets = [ cls._from_tsv( root, cfg, split, tgt_dict, src_dict, is_train_split, pre_tokenizer, bpe_tokenizer, src_pre_tokenizer, src_bpe_tokenizer, append_eos=append_eos, use_src_lang_id=use_src_lang_id, ) for split in splits.split(",") ] if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0: # temperature-based sampling size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha) datasets = [ ResamplingDataset( d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0) ) for r, d in zip(size_ratios, datasets) ] return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]