# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import logging
import re
from argparse import Namespace
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data import encoders
from fairseq.data.audio.audio_utils import get_features_or_waveform
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.dataset_transforms import CompositeAudioDatasetTransform
from fairseq.data.audio.dataset_transforms.concataugment import ConcatAugment
from fairseq.data.audio.dataset_transforms.noisyoverlapaugment import (
    NoisyOverlapAugment,
)
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform
from fairseq.data.audio.waveform_transforms import CompositeAudioWaveformTransform


logger = logging.getLogger(__name__)


def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
    """
    Convert a list of 2D frames into a padded 3D tensor
    Args:
        frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
            length of i-th frame and f_dim is static dimension of features
    Returns:
        3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out
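

# Indices can arrive either as plain Python ints or as numpy integer scalars
# (e.g. when they come out of numpy-based index sampling in dataset transforms),
# so the helper below accepts both.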
def _is_int_or_np_int(n):
    return isinstance(n, int) or (
        isinstance(n, np.generic) and isinstance(n.item(), int)
    )


@dataclass
class SpeechToTextDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None


class SpeechToTextDataset(FairseqDataset):
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        append_eos=True,
    ):
        self.split, self.is_train_split = split, is_train_split
        self.cfg = cfg
        self.audio_paths, self.n_frames = audio_paths, n_frames
        self.n_samples = len(audio_paths)
        assert len(n_frames) == self.n_samples > 0
        assert src_texts is None or len(src_texts) == self.n_samples
        assert tgt_texts is None or len(tgt_texts) == self.n_samples
        assert speakers is None or len(speakers) == self.n_samples
        assert src_langs is None or len(src_langs) == self.n_samples
        assert tgt_langs is None or len(tgt_langs) == self.n_samples
        assert ids is None or len(ids) == self.n_samples
        assert (tgt_dict is None and tgt_texts is None) or (
            tgt_dict is not None and tgt_texts is not None
        )
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_langs, self.tgt_langs = src_langs, tgt_langs
        self.speakers = speakers
        self.tgt_dict = tgt_dict
        self.check_tgt_lang_tag()
        self.ids = ids
        self.shuffle = cfg.shuffle if is_train_split else False

        self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.cfg.get_feature_transforms(split, is_train_split)
        )
        self.waveform_transforms = CompositeAudioWaveformTransform.from_config_dict(
            self.cfg.get_waveform_transforms(split, is_train_split)
        )
        # TODO: add these to data_cfg.py
        self.dataset_transforms = CompositeAudioDatasetTransform.from_config_dict(
            self.cfg.get_dataset_transforms(split, is_train_split)
        )

        # check proper usage of transforms
        if self.feature_transforms and self.cfg.use_audio_input:
            logger.warning(
                "Feature transforms will not be applied. To use feature transforms, "
                "set use_audio_input as False in config."
            )

        self.pre_tokenizer = pre_tokenizer
        self.bpe_tokenizer = bpe_tokenizer
        self.n_frames_per_step = n_frames_per_step
        self.speaker_to_id = speaker_to_id

        self.tgt_lens = self.get_tgt_lens_and_check_oov()
        self.append_eos = append_eos

        logger.info(self.__repr__())
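
    # Tokenize every target once up front: this yields per-sample target lengths
    # (used when batching by size) and logs the share of target tokens that map
    # to <unk> in tgt_dict.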
    def get_tgt_lens_and_check_oov(self):
        if self.tgt_texts is None:
            return [0 for _ in range(self.n_samples)]
        tgt_lens = []
        n_tokens, n_oov_tokens = 0, 0
        for i in range(self.n_samples):
            tokenized = self.get_tokenized_tgt_text(i).split(" ")
            oov_tokens = [
                t
                for t in tokenized
                if self.tgt_dict.index(t) == self.tgt_dict.unk_index
            ]
            n_tokens += len(tokenized)
            n_oov_tokens += len(oov_tokens)
            tgt_lens.append(len(tokenized))
        logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
        return tgt_lens

    def __repr__(self):
        return (
            self.__class__.__name__
            + f'(split="{self.split}", n_samples={self.n_samples:_}, '
            f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
            f"n_frames_per_step={self.n_frames_per_step}, "
            f"shuffle={self.shuffle}, "
            f"feature_transforms={self.feature_transforms}, "
            f"waveform_transforms={self.waveform_transforms}, "
            f"dataset_transforms={self.dataset_transforms})"
        )

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    def check_tgt_lang_tag(self):
        if self.cfg.prepend_tgt_lang_tag:
            assert self.tgt_langs is not None and self.tgt_dict is not None
            tgt_lang_tags = [
                self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
            ]
            assert all(t in self.tgt_dict for t in tgt_lang_tags)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: Union[int, List[int]]):
        if _is_int_or_np_int(index):
            text = self.tgt_texts[index]
        else:
            text = " ".join([self.tgt_texts[i] for i in index])

        text = self.tokenize(self.pre_tokenizer, text)
        text = self.tokenize(self.bpe_tokenizer, text)
        return text
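
    # pack_frames groups every n_frames_per_step consecutive frames into one
    # "super frame" by reshaping along the feature dimension, dropping any
    # trailing remainder frames.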
    def pack_frames(self, feature: torch.Tensor):
        if self.n_frames_per_step == 1:
            return feature
        n_packed_frames = feature.shape[0] // self.n_frames_per_step
        feature = feature[: self.n_frames_per_step * n_packed_frames]
        return feature.reshape(n_packed_frames, -1)

    @classmethod
    def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
        lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
        assert lang_tag_idx != dictionary.unk()
        return lang_tag_idx

    def _get_source_audio(self, index: Union[int, List[int]]) -> torch.Tensor:
        """
        Gives source audio for given index with any relevant transforms
        applied. For ConcatAug, source audios for given indices are
        concatenated in given order.
        Args:
            index (int or List[int]): index, or in the case of ConcatAug,
                indices, to pull the source audio for
        Returns:
            source audios concatenated for given indices with
            relevant transforms applied
        """
        if _is_int_or_np_int(index):
            source = get_features_or_waveform(
                self.audio_paths[index],
                need_waveform=self.cfg.use_audio_input,
                use_sample_rate=self.cfg.use_sample_rate,
                waveform_transforms=self.waveform_transforms,
            )
        else:
            source = np.concatenate(
                [
                    get_features_or_waveform(
                        self.audio_paths[i],
                        need_waveform=self.cfg.use_audio_input,
                        use_sample_rate=self.cfg.use_sample_rate,
                        waveform_transforms=self.waveform_transforms,
                    )
                    for i in index
                ]
            )
        if self.cfg.use_audio_input:
            source = torch.from_numpy(source).float()
            if self.cfg.standardize_audio:
                with torch.no_grad():
                    source = F.layer_norm(source, source.shape)
        else:
            if self.feature_transforms is not None:
                source = self.feature_transforms(source)
            source = torch.from_numpy(source).float()
        return source
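
    # An item bundles the (possibly concatenated) source audio with its encoded
    # target. Depending on the config, the target is either prefixed with a
    # <lang:xx> tag or wrapped as <bos> ... <lang:xx>.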
    def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
        has_concat = self.dataset_transforms.has_transform(ConcatAugment)
        if has_concat:
            concat = self.dataset_transforms.get_transform(ConcatAugment)
            indices = concat.find_indices(index, self.n_frames, self.n_samples)

        source = self._get_source_audio(indices if has_concat else index)
        source = self.pack_frames(source)

        target = None
        if self.tgt_texts is not None:
            tokenized = self.get_tokenized_tgt_text(indices if has_concat else index)
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=self.append_eos
            ).long()
            if self.cfg.prepend_tgt_lang_tag:
                lang_tag_idx = self.get_lang_tag_idx(
                    self.tgt_langs[index], self.tgt_dict
                )
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)

            if self.cfg.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.tgt_dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
                assert lang_tag_idx != self.tgt_dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)

        speaker_id = None
        if self.speaker_to_id is not None:
            speaker_id = self.speaker_to_id[self.speakers[index]]

        return SpeechToTextDatasetItem(
            index=index, source=source, target=target, speaker_id=speaker_id
        )

    def __len__(self):
        return self.n_samples
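
    # Collation pads sources and targets to the longest item in the mini-batch
    # and sorts the batch by descending number of source frames; all returned
    # tensors follow that sorted order.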
    def collater(
        self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
    ) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)

        sources = [x.source for x in samples]
        has_NOAug = self.dataset_transforms.has_transform(NoisyOverlapAugment)
        if has_NOAug and self.cfg.use_audio_input:
            NOAug = self.dataset_transforms.get_transform(NoisyOverlapAugment)
            sources = NOAug(sources)

        frames = _collate_frames(sources, self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.size(0) for x in sources], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [x.target.size(0) for x in samples], dtype=torch.long
            ).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                eos_idx=None,
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(x.target.size(0) for x in samples)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
        }
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": speaker,
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out
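
    # Size helpers consumed by fairseq's batching and size-filtering utilities.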
    def num_tokens(self, index):
        return self.n_frames[index]

    def size(self, index):
        return self.n_frames[index], self.tgt_lens[index]

    @property
    def sizes(self):
        return np.array(self.n_frames)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # first by descending order of # of frames then by original/random order
        order.append([-n for n in self.n_frames])
        return np.lexsort(order)

    def prefetch(self, indices):
        raise NotImplementedError
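

# Auxiliary text targets for multitask training: a mapping from sample id to
# target text, loaded from the same TSV manifest format and encoded with the
# task's own dictionary and tokenizers.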
class TextTargetMultitaskData(object):
    # mandatory columns
    KEY_ID, KEY_TEXT = "id", "tgt_text"
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(self, args, split, tgt_dict):
        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
        self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
        self.dict = tgt_dict
        self.append_eos = args.decoder_type != "ctc"
        self.pre_tokenizer = self.build_tokenizer(args)
        self.bpe_tokenizer = self.build_bpe(args)
        self.prepend_bos_and_append_tgt_lang_tag = (
            args.prepend_bos_and_append_tgt_lang_tag
        )
        self.eos_token = args.eos_token
        self.lang_tag_mapping = args.get_lang_tag_mapping

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: int):
        text = self.tokenize(self.pre_tokenizer, self.data[index])
        text = self.tokenize(self.bpe_tokenizer, text)
        return text

    def get_lang_tag_idx(self, lang: str, dictionary: Dictionary):
        lang_tag = self.LANG_TAG_TEMPLATE.format(lang)
        lang_tag = self.lang_tag_mapping.get(lang_tag, lang_tag)
        lang_tag_idx = dictionary.index(lang_tag)
        assert lang_tag_idx != dictionary.unk(), (lang, lang_tag)
        return lang_tag_idx

    def build_tokenizer(self, args):
        pre_tokenizer = args.config.get("pre_tokenizer")
        if pre_tokenizer is not None:
            logger.info(f"pre-tokenizer: {pre_tokenizer}")
            return encoders.build_tokenizer(Namespace(**pre_tokenizer))
        else:
            return None

    def build_bpe(self, args):
        bpe_tokenizer = args.config.get("bpe_tokenizer")
        if bpe_tokenizer is not None:
            logger.info(f"tokenizer: {bpe_tokenizer}")
            return encoders.build_bpe(Namespace(**bpe_tokenizer))
        else:
            return None

    def get(self, sample_id, tgt_lang=None):
        if sample_id in self.data:
            tokenized = self.get_tokenized_tgt_text(sample_id)
            target = self.dict.encode_line(
                tokenized,
                add_if_not_exist=False,
                append_eos=self.append_eos,
            )
            if self.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(tgt_lang, self.dict)
                assert lang_tag_idx != self.dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)
            return target
        else:
            logger.warning(f"no target for {sample_id}")
            return torch.IntTensor([])

    def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
        out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=False,
        ).long()

        prev_out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            eos_idx=None,
            left_pad=False,
            move_eos_to_beginning=True,
        ).long()

        target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
        ntokens = sum(t.size(0) for t in samples)

        output = {
            "prev_output_tokens": prev_out,
            "target": out,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
        }

        return output
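

# Speech-to-text dataset that additionally returns, per sample, one encoded
# target per registered multitask (keyed by task name); its collater reorders
# the multitask targets to match the frame-sorted order of the main batch.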
class SpeechToTextMultitaskDataset(SpeechToTextDataset):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]:
        s2t_data = super().__getitem__(index)

        multitask_target = {}
        sample_id = self.ids[index]
        tgt_lang = self.tgt_langs[index]
        for task_name, task_dataset in self.multitask_data.items():
            multitask_target[task_name] = task_dataset.get(sample_id, tgt_lang)

        return s2t_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToTextDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        if len(samples) == 0:
            return {}

        out = super().collater([s for s, _ in samples], return_order=True)
        order = out["order"]
        del out["order"]

        for task_name, task_dataset in self.multitask_data.items():
            if "multitask" not in out:
                out["multitask"] = {}
            d = [s[task_name] for _, s in samples]
            task_target = task_dataset.collater(d)
            out["multitask"][task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(0, order),
                "ntokens": task_target["ntokens"],
            }
            out["multitask"][task_name]["net_input"] = {
                "prev_output_tokens": task_target["prev_output_tokens"].index_select(
                    0, order
                ),
            }

        return out
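

# Builds SpeechToTextDataset instances from TSV manifests. Each row must provide
# id, audio, n_frames and tgt_text; speaker, src_text, src_lang and tgt_lang are
# optional and fall back to the defaults below.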
class SpeechToTextDatasetCreator(object):
    # mandatory columns
    KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
    KEY_TGT_TEXT = "tgt_text"
    # optional columns
    KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        has_multitask = multitask is not None and len(multitask.keys()) > 0
        dataset_cls = (
            SpeechToTextMultitaskDataset if has_multitask else SpeechToTextDataset
        )

        ds = dataset_cls(
            split=split_name,
            is_train_split=is_train_split,
            cfg=cfg,
            audio_paths=audio_paths,
            n_frames=n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )

        if has_multitask:
            for task_name, task_obj in multitask.items():
                task_data = TextTargetMultitaskData(
                    task_obj.args, split_name, task_obj.target_dictionary
                )
                ds.add_multitask_dataset(task_name, task_data)
        return ds
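
    # Temperature-based sampling: with per-language-pair sizes n_l and N = sum(n_l),
    # the target distribution is p_l proportional to (n_l / N) ** alpha (renormalized),
    # and each dataset is up/down-sampled by size_ratio_l = p_l * N / n_l.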
    @classmethod
    def get_size_ratios(
        cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
    ) -> List[float]:
        """Size ratios for temperature-based sampling
        (https://arxiv.org/abs/1907.05019)"""
        id_to_lp, lp_to_sz = {}, defaultdict(int)
        for ds in datasets:
            lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
            assert len(lang_pairs) == 1
            lang_pair = list(lang_pairs)[0]
            id_to_lp[ds.split] = lang_pair
            lp_to_sz[lang_pair] += sum(ds.n_frames)

        sz_sum = sum(v for v in lp_to_sz.values())
        lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
        lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
        prob_sum = sum(v for v in lp_to_tgt_prob.values())
        lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
        lp_to_sz_ratio = {
            k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
        }
        size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]

        p_formatted = {
            k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
        }
        logger.info(f"sampling probability balancing: {p_formatted}")
        sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
        logger.info(f"balanced sampling size ratio: {sr_formatted}")
        return size_ratio

    @classmethod
    def _load_samples_from_tsv(cls, root: str, split: str):
        tsv_path = Path(root) / f"{split}.tsv"
        if not tsv_path.is_file():
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            samples = [dict(e) for e in reader]
        if len(samples) == 0:
            raise ValueError(f"Empty manifest: {tsv_path}")
        return samples

    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        samples = cls._load_samples_from_tsv(root, split)
        return cls._from_list(
            split,
            is_train_split,
            samples,
            cfg,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            multitask,
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        splits: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        n_frames_per_step: int = 1,
        speaker_to_id=None,
        multitask: Optional[Dict] = None,
    ) -> SpeechToTextDataset:
        datasets = [
            cls._from_tsv(
                root=root,
                cfg=cfg,
                split=split,
                tgt_dict=tgt_dict,
                is_train_split=is_train_split,
                pre_tokenizer=pre_tokenizer,
                bpe_tokenizer=bpe_tokenizer,
                n_frames_per_step=n_frames_per_step,
                speaker_to_id=speaker_to_id,
                multitask=multitask,
            )
            for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(
                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
                )
                for r, d in zip(size_ratios, datasets)
            ]
        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
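

# Illustrative use (variable names and paths below are hypothetical): a task
# would typically call
#   SpeechToTextDatasetCreator.from_tsv(
#       root="/path/to/manifests", cfg=data_cfg, splits="train_a,train_b",
#       tgt_dict=tgt_dict, pre_tokenizer=pre_tok, bpe_tokenizer=bpe_tok,
#       is_train_split=True, epoch=1, seed=1,
#   )
# which loads one dataset per comma-separated split, concatenates them, and,
# for training with cfg.sampling_alpha != 1.0, resamples each split with the
# ratios computed by get_size_ratios().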