|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, NamedTuple, Optional |
|
|
|
|
|
import torch |
|
|
|
|
|
from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset |
|
|
from fairseq.data import data_utils as fairseq_data_utils |
|
|
from fairseq.data.audio.speech_to_text_dataset import ( |
|
|
S2TDataConfig, |
|
|
SpeechToTextDataset, |
|
|
SpeechToTextDatasetCreator, |
|
|
) |
|
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class S2TJointDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML.

    Adds accessors for the source-text related options. Each accessor looks
    up one key in ``self.config`` and falls back to a default when the key
    is missing.
    """

    @property
    def src_vocab_filename(self):
        """fairseq vocabulary file (source side) under data root.

        Falls back to ``"src_dict.txt"``.
        """
        filename = self.config.get("src_vocab_filename", "src_dict.txt")
        return filename

    @property
    def src_pre_tokenizer(self) -> Dict:
        """Pre-tokenizer applied to source text before subword tokenization.

        The returned dictionary names the tokenizer under ``tokenizer``; the
        remaining items are the tokenizer-specific arguments. Tokenizers are
        defined in `fairseq.data.encoders.*`. Falls back to
        ``{"tokenizer": None}``.
        """
        fallback = {"tokenizer": None}
        return self.config.get("src_pre_tokenizer", fallback)

    @property
    def src_bpe_tokenizer(self) -> Dict:
        """Subword tokenizer applied to source text after pre-tokenization.

        The returned dictionary names the tokenizer under ``bpe``; the
        remaining items are the tokenizer-specific arguments. Tokenizers are
        defined in `fairseq.data.encoders.*`. Falls back to ``{"bpe": None}``.
        """
        fallback = {"bpe": None}
        return self.config.get("src_bpe_tokenizer", fallback)

    @property
    def prepend_tgt_lang_tag_no_change(self) -> bool:
        """Whether to use the target lang ID token as prev_output_tokens BOS.

        Intended e.g. for the to-many multilingual setting; no change is
        needed during inference. This option is deprecated and replaced by
        ``prepend_tgt_lang_tag_as_bos``: the deprecated key wins when it is
        present with a non-``None`` value, otherwise the replacement key is
        consulted (default ``False``).
        """
        deprecated_value = self.config.get("prepend_tgt_lang_tag_no_change", None)
        if deprecated_value is not None:
            return deprecated_value
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)

    @property
    def sampling_text_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.

        Applies to text input only; alpha = 1 (the default) means no
        resampling.
        """
        alpha = self.config.get("sampling_text_alpha", 1.0)
        return alpha
|
|
|
|
|
|
|
|
class SpeechToTextJointDatasetItem(NamedTuple):
    """Single example as returned by ``SpeechToTextJointDataset.__getitem__``.

    Only ``index`` and ``source`` are required; every other field defaults
    to ``None``.
    """

    # Position of the example inside the dataset.
    index: int
    # ``source`` of the item produced by the parent speech-to-text dataset.
    source: torch.Tensor
    # ``target`` of the item produced by the parent speech-to-text dataset.
    target: Optional[torch.Tensor] = None
    # Source text encoded with the source dictionary.
    src_txt_tokens: Optional[torch.Tensor] = None
    # Dictionary index of the target language tag.
    tgt_lang_tag: Optional[int] = None
    # Dictionary index of the source language tag.
    src_lang_tag: Optional[int] = None
    # Per-example alignment values as a float tensor.
    tgt_alignment: Optional[torch.Tensor] = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SpeechToTextJointDataset(SpeechToTextDataset):
    """Speech-to-text dataset that additionally serves tokenized source text.

    On top of what ``SpeechToTextDataset`` provides, each item can carry the
    encoded source text, source/target language tag indices and an optional
    alignment vector; ``collater`` adds the corresponding batch entries to
    ``net_input``.
    """

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TJointDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        src_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        src_pre_tokenizer=None,
        src_bpe_tokenizer=None,
        append_eos: Optional[bool] = True,
        alignment: Optional[List[str]] = None,
        use_src_lang_id: Optional[int] = 0,
    ):
        """Set up the parent dataset and store the source-text extras.

        All arguments other than the ones listed below are forwarded
        unchanged to ``SpeechToTextDataset.__init__``.

        Args:
            src_dict: dictionary used to encode the source text.
            src_pre_tokenizer: pre-tokenizer applied to the source text.
            src_bpe_tokenizer: subword tokenizer applied after
                ``src_pre_tokenizer``.
            alignment: one string of whitespace-separated numbers per
                sample; each string is parsed into a list of floats.
            use_src_lang_id: values > 0 enable source language tags; in
                ``collater`` only the value 1 is implemented.
        """
        super().__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            append_eos=append_eos,
        )

        self.src_dict = src_dict
        self.src_pre_tokenizer = src_pre_tokenizer
        self.src_bpe_tokenizer = src_bpe_tokenizer
        self.alignment = None
        self.use_src_lang_id = use_src_lang_id
        if alignment is not None:
            # One list of floats per sample.
            self.alignment = [list(map(float, line.split())) for line in alignment]

    def get_tokenized_src_text(self, index: int):
        """Return source text ``index`` after pre- and subword tokenization."""
        pre_tokenized = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
        return self.tokenize(self.src_bpe_tokenizer, pre_tokenized)

    def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
        """Build the joint item for ``index``.

        ``source`` and ``target`` come from the parent dataset's item. The
        other fields stay ``None`` unless the matching feature is enabled.
        """
        base_item = super().__getitem__(index)

        src_tokens = None
        src_lang_tag = None
        if self.src_texts is not None and self.src_dict is not None:
            tokenized = self.get_tokenized_src_text(index)
            src_tokens = self.src_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=True
            ).long()
            if self.use_src_lang_id > 0:
                src_lang_tag = self.get_lang_tag_idx(
                    self.src_langs[index], self.src_dict
                )

        # The tag is not inserted here; collater() writes it into the first
        # position of prev_output_tokens.
        tgt_lang_tag = (
            self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
            if self.cfg.prepend_tgt_lang_tag_no_change
            else None
        )

        if self.alignment is None:
            ali = None
        else:
            ali = torch.Tensor(self.alignment[index]).float()

        return SpeechToTextJointDatasetItem(
            index=index,
            source=base_item.source,
            target=base_item.target,
            src_txt_tokens=src_tokens,
            tgt_lang_tag=tgt_lang_tag,
            src_lang_tag=src_lang_tag,
            tgt_alignment=ali,
        )

    def __len__(self):
        """Number of samples in the dataset."""
        return self.n_samples

    def _collate_src_text(self, samples, order):
        """Pad the source-text tokens of ``samples`` and reorder by ``order``.

        Returns the padded token matrix and the per-sample lengths, both
        already re-indexed with ``order``.
        """
        token_seqs = [sample.src_txt_tokens for sample in samples]
        padded = fairseq_data_utils.collate_tokens(
            token_seqs,
            self.src_dict.pad(),
            self.src_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        )
        lengths = torch.tensor(
            [seq.size()[0] for seq in token_seqs], dtype=torch.long
        )
        if self.use_src_lang_id > 0:
            lang_ids = torch.tensor(
                [sample.src_lang_tag for sample in samples], dtype=padded.dtype
            )
            if self.use_src_lang_id != 1:
                raise NotImplementedError("Implementation is required")
            # Overwrite the last token of every sequence (index length - 1)
            # with the source language tag.
            last_pos = lengths - 1
            padded.scatter_(1, last_pos.view(-1, 1), lang_ids.view(-1, 1))
        return padded.index_select(0, order), lengths.index_select(0, order)

    def _collate_alignment(self, samples, order):
        """Stack the alignment vectors of ``samples``, padded with ones."""
        lengths = [sample.tgt_alignment.size(0) for sample in samples]
        padded = torch.ones(len(samples), max(lengths)).float()
        for row, (sample, length) in enumerate(zip(samples, lengths)):
            padded[row][:length].copy_(sample.tgt_alignment)
        return padded.index_select(0, order)

    def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
        """Collate ``samples`` into a batch dictionary.

        Starts from the parent collater's output (requested with
        ``return_order=True``) and extends its ``net_input`` with
        ``src_txt_tokens`` / ``src_txt_lengths`` (when source text and a
        source dictionary are available) and ``alignment`` (``None`` when no
        alignment was given). An empty parent batch is returned as is.
        """
        s2t_out = super().collater(samples, return_order=True)
        if s2t_out == {}:
            return s2t_out
        net_input = s2t_out["net_input"]
        order = s2t_out["order"]

        if self.src_texts is not None and self.src_dict is not None:
            src_txt_tokens, src_txt_lengths = self._collate_src_text(samples, order)
            net_input["src_txt_tokens"] = src_txt_tokens
            net_input["src_txt_lengths"] = src_txt_lengths

        net_input["alignment"] = None
        if self.alignment is not None:
            net_input["alignment"] = self._collate_alignment(samples, order)

        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            # Batch row i corresponds to samples[order[i]]; put that sample's
            # target language tag into the first decoder input position.
            prev_output_tokens = net_input["prev_output_tokens"]
            for row in range(len(samples)):
                prev_output_tokens[row][0] = samples[order[row]].tgt_lang_tag

        return {
            "id": s2t_out["id"],
            "net_input": net_input,
            "target": s2t_out["target"],
            "target_lengths": s2t_out["target_lengths"],
            "ntokens": s2t_out["ntokens"],
            "nsentences": len(samples),
        }
|
|
|
|
|
|
|
|
class SpeechToTextJointDatasetCreator(SpeechToTextDatasetCreator):
    """Factory that builds ``SpeechToTextJointDataset`` objects from TSV data."""

    # Optional per-sample column holding the alignment string.
    KEY_ALIGN = "align"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TJointDataConfig,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos,
        use_src_lang_id,
    ) -> SpeechToTextJointDataset:
        """Build one dataset from a list of per-sample dictionaries.

        Audio paths are resolved relative to ``cfg.audio_root``. Source text,
        speaker and language columns fall back to the class defaults when
        absent. Alignment is read only when the first sample has the
        ``KEY_ALIGN`` column.
        """
        audio_root = Path(cfg.audio_root)
        ids = [row[cls.KEY_ID] for row in samples]
        audio_paths = [(audio_root / row[cls.KEY_AUDIO]).as_posix() for row in samples]
        n_frames = [int(row[cls.KEY_N_FRAMES]) for row in samples]
        tgt_texts = [row[cls.KEY_TGT_TEXT] for row in samples]
        src_texts = [
            row.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for row in samples
        ]
        speakers = [row.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for row in samples]
        src_langs = [row.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for row in samples]
        tgt_langs = [row.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for row in samples]
        # The first sample decides whether the split carries alignments.
        tgt_alignment = (
            [row[cls.KEY_ALIGN] for row in samples]
            if cls.KEY_ALIGN in samples[0]
            else None
        )
        return SpeechToTextJointDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            src_dict=src_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            src_pre_tokenizer=src_pre_tokenizer,
            src_bpe_tokenizer=src_bpe_tokenizer,
            append_eos=append_eos,
            alignment=tgt_alignment,
            use_src_lang_id=use_src_lang_id,
        )

    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        split: str,
        tgt_dict,
        src_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        append_eos: bool,
        use_src_lang_id: int,
    ) -> SpeechToTextJointDataset:
        """Load the samples of one split and hand them to ``_from_list``."""
        loaded = cls._load_samples_from_tsv(root, split)
        return cls._from_list(
            split,
            is_train_split,
            loaded,
            cfg,
            tgt_dict,
            src_dict,
            pre_tokenizer,
            bpe_tokenizer,
            src_pre_tokenizer,
            src_bpe_tokenizer,
            append_eos,
            use_src_lang_id,
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TJointDataConfig,
        splits: str,
        tgt_dict,
        src_dict,
        pre_tokenizer,
        bpe_tokenizer,
        src_pre_tokenizer,
        src_bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        append_eos: Optional[bool] = True,
        use_src_lang_id: Optional[int] = 0,
    ) -> SpeechToTextJointDataset:
        """Build the dataset for one or more comma-separated splits.

        One dataset is created per name in ``splits``. For a training split
        with more than one dataset and ``cfg.sampling_alpha != 1.0``, every
        dataset is wrapped in a ``ResamplingDataset`` whose size ratio comes
        from ``cls.get_size_ratios``; sampling is with replacement when the
        ratio is at least 1.0.

        Returns:
            The single dataset when there is exactly one split, otherwise a
            ``ConcatDataset`` over all of them.
        """
        datasets = [
            cls._from_tsv(
                root,
                cfg,
                split_name,
                tgt_dict,
                src_dict,
                is_train_split,
                pre_tokenizer,
                bpe_tokenizer,
                src_pre_tokenizer,
                src_bpe_tokenizer,
                append_eos=append_eos,
                use_src_lang_id=use_src_lang_id,
            )
            for split_name in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # Temperature-based resampling across the splits.
            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(
                    dataset,
                    size_ratio=ratio,
                    seed=seed,
                    epoch=epoch,
                    replace=(ratio >= 1.0),
                )
                for ratio, dataset in zip(size_ratios, datasets)
            ]

        if len(datasets) > 1:
            return ConcatDataset(datasets)
        return datasets[0]
|
|
|