| |
| |
| |
| |
|
|
| import logging |
| import os.path as op |
| from argparse import Namespace |
|
|
| from fairseq.data import Dictionary, encoders |
| from fairseq.data.audio.speech_to_text_dataset import ( |
| S2TDataConfig, |
| SpeechToTextDataset, |
| SpeechToTextDatasetCreator, |
| ) |
| from fairseq.tasks import FairseqTask, register_task |
|
|
|
|
# Module-level logger, named after this module so log records can be
# filtered per module.
logger = logging.getLogger(__name__)
|
|
|
|
@register_task("speech_to_text")
class SpeechToTextTask(FairseqTask):
    """Task registered under the name ``speech_to_text``.

    The task reads an ``S2TDataConfig`` from a YAML file located under the
    manifest root (``args.data``), loads the target dictionary named by that
    config, and builds datasets through
    ``SpeechToTextDatasetCreator.from_tsv``. There is no source dictionary.
    """

    @staticmethod
    def add_args(parser):
        """Register the task's command-line arguments on ``parser``.

        Adds the positional ``data`` argument plus ``--config-yaml``,
        ``--max-source-positions`` and ``--max-target-positions``.
        """
        parser.add_argument("data", help="manifest root path")
        parser.add_argument(
            "--config-yaml",
            type=str,
            default="config.yaml",
            help="Configuration YAML filename (under manifest root)",
        )
        parser.add_argument(
            "--max-source-positions",
            default=6000,
            type=int,
            metavar="N",
            help="max number of tokens in the source sequence",
        )
        parser.add_argument(
            "--max-target-positions",
            default=1024,
            type=int,
            metavar="N",
            help="max number of tokens in the target sequence",
        )

    def __init__(self, args, tgt_dict):
        """Store the target dictionary and load the data config.

        Args:
            args: parsed arguments; ``args.data`` and ``args.config_yaml``
                are joined to locate the YAML data config.
            tgt_dict: target-side dictionary.
        """
        super().__init__(args)
        self.tgt_dict = tgt_dict
        self.data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Load the target dictionary and construct the task.

        Args:
            args: parsed arguments (see :meth:`add_args`).
            **kwargs: accepted for interface compatibility; unused here.

        Returns:
            A new task instance.

        Raises:
            FileNotFoundError: if the dictionary file named by the data
                config does not exist under ``args.data``.
            ValueError: if ``args.train_subset`` is set and any of its
                comma-separated names does not start with ``"train"``.
        """
        data_cfg = S2TDataConfig(op.join(args.data, args.config_yaml))
        dict_path = op.join(args.data, data_cfg.vocab_filename)
        if not op.isfile(dict_path):
            raise FileNotFoundError(f"Dict not found: {dict_path}")
        tgt_dict = Dictionary.load(dict_path)
        logger.info(
            f"dictionary size ({data_cfg.vocab_filename}): {len(tgt_dict):,}"
        )

        # load_dataset() decides train/non-train purely from the "train"
        # name prefix, so reject training subsets that would be misdetected.
        if getattr(args, "train_subset", None) is not None:
            if not all(s.startswith("train") for s in args.train_subset.split(",")):
                raise ValueError('Train splits should be named like "train*".')
        return cls(args, tgt_dict)

    def build_criterion(self, args):
        """Build the criterion for this task.

        Raises:
            ValueError: if the data config prepends a target language tag
                and ``args.ignore_prefix_size`` is not 1.
        """
        # Imported locally, as in the original code.
        from fairseq import criterions

        if self.data_cfg.prepend_tgt_lang_tag and args.ignore_prefix_size != 1:
            raise ValueError(
                'Please set "--ignore-prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        return criterions.build_criterion(args, self)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Build the dataset for ``split`` and store it in ``self.datasets``.

        Args:
            split: split name; names starting with ``"train"`` are treated
                as training splits.
            epoch: forwarded to the dataset creator.
            combine: accepted for interface compatibility; unused here.
            **kwargs: accepted for interface compatibility; unused here.
        """
        is_train_split = split.startswith("train")
        pre_tokenizer = self.build_tokenizer(self.args)
        bpe_tokenizer = self.build_bpe(self.args)
        self.datasets[split] = SpeechToTextDatasetCreator.from_tsv(
            self.args.data,
            self.data_cfg,
            split,
            self.tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            is_train_split=is_train_split,
            epoch=epoch,
            seed=self.args.seed,
        )

    @property
    def target_dictionary(self):
        """The target dictionary passed to the constructor."""
        return self.tgt_dict

    @property
    def source_dictionary(self):
        """Always ``None``: this task defines no source dictionary."""
        return None

    def max_positions(self):
        """Return ``(max_source_positions, max_target_positions)`` from args."""
        return self.args.max_source_positions, self.args.max_target_positions

    def build_model(self, args):
        """Build the model after copying input feature sizes onto ``args``.

        ``input_feat_per_channel`` and ``input_channels`` are taken from the
        data config and written onto ``args`` (mutating it) before the base
        class builds the model.
        """
        args.input_feat_per_channel = self.data_cfg.input_feat_per_channel
        args.input_channels = self.data_cfg.input_channels
        return super().build_model(args)

    def build_generator(
        self,
        models,
        args,
        seq_gen_cls=None,
        extra_gen_cls_kwargs=None,
    ):
        """Build a generator that strips language-tag tokens from its output.

        Args:
            models: models to generate with.
            args: generation arguments; ``args.prefix_size`` is validated.
            seq_gen_cls: optional generator class, forwarded to the base
                implementation.
            extra_gen_cls_kwargs: optional extra keyword arguments for the
                generator class. The dict is copied, not mutated;
                ``symbols_to_strip_from_output`` is always set to the
                language-tag token ids.

        Raises:
            ValueError: if the data config prepends a target language tag
                and ``args.prefix_size`` is not 1.
        """
        if self.data_cfg.prepend_tgt_lang_tag and args.prefix_size != 1:
            raise ValueError(
                'Please set "--prefix-size 1" since '
                "target language ID token is prepended as BOS."
            )
        lang_token_ids = {
            i
            for s, i in self.tgt_dict.indices.items()
            if SpeechToTextDataset.is_lang_tag(s)
        }
        # Previously the caller's seq_gen_cls and extra_gen_cls_kwargs were
        # silently discarded; honour them, working on a copy so the caller's
        # dict is left untouched.
        gen_cls_kwargs = dict(extra_gen_cls_kwargs) if extra_gen_cls_kwargs else {}
        gen_cls_kwargs["symbols_to_strip_from_output"] = lang_token_ids
        return super().build_generator(
            models,
            args,
            seq_gen_cls=seq_gen_cls,
            extra_gen_cls_kwargs=gen_cls_kwargs,
        )

    def build_tokenizer(self, args):
        """Build the pre-tokenizer described by the data config.

        ``args`` is accepted for interface compatibility; the settings come
        from ``self.data_cfg.pre_tokenizer``.
        """
        logger.info(f"pre-tokenizer: {self.data_cfg.pre_tokenizer}")
        return encoders.build_tokenizer(Namespace(**self.data_cfg.pre_tokenizer))

    def build_bpe(self, args):
        """Build the BPE tokenizer described by the data config.

        ``args`` is accepted for interface compatibility; the settings come
        from ``self.data_cfg.bpe_tokenizer``.
        """
        logger.info(f"tokenizer: {self.data_cfg.bpe_tokenizer}")
        return encoders.build_bpe(Namespace(**self.data_cfg.bpe_tokenizer))

    @classmethod
    def build_dataset_for_inference(cls, audio_paths, n_frames):
        """Wrap audio paths and frame counts in an ``"interactive"`` dataset.

        NOTE(review): an empty dict is passed where the other construction
        path in this class supplies ``self.data_cfg`` -- confirm that
        ``SpeechToTextDataset`` accepts it. Being a classmethod, no instance
        data config is available here.
        """
        return SpeechToTextDataset("interactive", False, {}, audio_paths, n_frames)
|
|