# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
| import logging | |
| from argparse import Namespace | |
| from copy import deepcopy | |
| from pathlib import Path | |
| from typing import Dict, Optional | |
| from fairseq.data import Dictionary | |
| logger = logging.getLogger(__name__) | |
def get_config_from_yaml(yaml_path: Path):
    """Load a YAML config file into a dict.

    Args:
        yaml_path: path to an existing YAML file.

    Returns:
        The parsed configuration dict.

    Raises:
        ImportError: if PyYAML is not installed.
        FileNotFoundError: if ``yaml_path`` does not exist.
        Exception: if the file exists but fails to parse.
    """
    try:
        import yaml
    except ImportError:
        # The original only printed here and fell through, so the later
        # `yaml.load` call crashed with a confusing NameError. Fail fast
        # with an actionable message instead.
        raise ImportError("Please install PyYAML: pip install PyYAML")

    if not yaml_path.is_file():
        raise FileNotFoundError(f"{yaml_path.as_posix()} not found")

    try:
        with open(yaml_path) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
    except Exception as e:
        # Chain the cause so parse failures keep their original traceback.
        raise Exception(
            f"Failed to load config from {yaml_path.as_posix()}: {e}"
        ) from e
    return config
class S2TDataConfig(object):
    """Wrapper class for speech-to-text data config YAML."""

    def __init__(self, yaml_path: Path):
        self.config = get_config_from_yaml(yaml_path)
        self.root = yaml_path.parent

    def _auto_convert_to_abs_path(self, x):
        # Resolve values that are paths relative to the config file's own
        # directory into absolute POSIX paths; dict values are converted
        # recursively, everything else is returned untouched.
        if isinstance(x, str):
            if not Path(x).exists() and (self.root / x).exists():
                return (self.root / x).as_posix()
        elif isinstance(x, dict):
            return {k: self._auto_convert_to_abs_path(v) for k, v in x.items()}
        return x

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("vocab_filename", "dict.txt")

    @property
    def speaker_set_filename(self):
        """speaker set file under data root"""
        return self.config.get("speaker_set_filename", None)

    @property
    def shuffle(self) -> bool:
        """Shuffle dataset samples before batching"""
        return self.config.get("shuffle", False)

    @property
    def pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("pre_tokenizer", {"tokenizer": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply after pre-tokenization. Returning
        a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        tokenizer = self.config.get("bpe_tokenizer", {"bpe": None})
        return self._auto_convert_to_abs_path(tokenizer)

    @property
    def prepend_tgt_lang_tag(self) -> bool:
        """Prepend target lang ID token as the target BOS (e.g. for to-many
        multilingual setting). During inference, this requires `--prefix-size 1`
        to force BOS to be lang ID token."""
        return self.config.get("prepend_tgt_lang_tag", False)

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g.
        mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def input_feat_per_channel(self):
        """The dimension of input features (per audio channel)"""
        return self.config.get("input_feat_per_channel", 80)

    @property
    def input_channels(self):
        """The number of channels in the input audio"""
        return self.config.get("input_channels", 1)

    @property
    def sample_rate(self):
        """Audio sample rate (Hz) used by the dataset."""
        return self.config.get("sample_rate", 16_000)

    @property
    def sampling_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling.
        (alpha = 1 for no resampling)"""
        return self.config.get("sampling_alpha", 1.0)

    @property
    def use_audio_input(self):
        """Needed by the dataset loader to see if the model requires
        raw audio as inputs."""
        return self.config.get("use_audio_input", False)

    @property
    def standardize_audio(self) -> bool:
        """Standardize raw audio; only meaningful when using raw audio input."""
        # `use_audio_input` must be a @property: as a plain bound method it is
        # always truthy and this guard would be a no-op (the stripped
        # decorators were the bug being fixed here).
        return self.use_audio_input and self.config.get("standardize_audio", False)

    @property
    def use_sample_rate(self):
        """Needed by the dataset loader to see if the model requires
        raw audio with specific sample rate as inputs."""
        return self.config.get("use_sample_rate", 16000)

    @property
    def audio_root(self):
        """Audio paths in the manifest TSV can be relative and this provides
        the root path. Set this to empty string when using absolute paths."""
        return self.config.get("audio_root", "")

    def get_transforms(self, transform_type, split, is_train):
        """Split-specific feature transforms. Allowing train set
        wildcard `_train`, evaluation set wildcard `_eval` and general
        wildcard `*` for matching."""
        # `deepcopy` comes from the module-level import; the redundant
        # function-local import was removed.
        cfg = deepcopy(self.config)
        _cur = cfg.get(f"{transform_type}transforms", {})
        cur = _cur.get(split)
        cur = _cur.get("_train") if cur is None and is_train else cur
        cur = _cur.get("_eval") if cur is None and not is_train else cur
        cur = _cur.get("*") if cur is None else cur
        return cur

    def get_feature_transforms(self, split, is_train):
        """Return a config copy with `feature_transforms` resolved for `split`,
        folding in legacy `transforms` entries (with a deprecation warning)."""
        cfg = deepcopy(self.config)
        # TODO: deprecate transforms
        cur = self.get_transforms("", split, is_train)
        if cur is not None:
            logger.warning(
                "Auto converting transforms into feature_transforms, "
                "but transforms will be deprecated in the future. Please "
                "update this in the config."
            )
            ft_transforms = self.get_transforms("feature_", split, is_train)
            if ft_transforms:
                cur.extend(ft_transforms)
        else:
            cur = self.get_transforms("feature_", split, is_train)
        cfg["feature_transforms"] = cur
        return cfg

    def get_waveform_transforms(self, split, is_train):
        """Return a config copy with `waveform_transforms` resolved for `split`."""
        cfg = deepcopy(self.config)
        cfg["waveform_transforms"] = self.get_transforms("waveform_", split, is_train)
        return cfg

    def get_dataset_transforms(self, split, is_train):
        """Return a config copy with `dataset_transforms` resolved for `split`."""
        cfg = deepcopy(self.config)
        cfg["dataset_transforms"] = self.get_transforms("dataset_", split, is_train)
        return cfg

    @property
    def global_cmvn_stats_npz(self) -> Optional[str]:
        """Path to global CMVN stats (.npz), resolved against data root."""
        path = self.config.get("global_cmvn", {}).get("stats_npz_path", None)
        return self._auto_convert_to_abs_path(path)

    @property
    def vocoder(self) -> Dict[str, str]:
        """Vocoder config dict; defaults to Griffin-Lim."""
        vocoder = self.config.get("vocoder", {"type": "griffin_lim"})
        return self._auto_convert_to_abs_path(vocoder)

    @property
    def hub(self) -> Dict[str, str]:
        """torch.hub-related entries from the config."""
        return self.config.get("hub", {})
class S2SDataConfig(S2TDataConfig):
    """Wrapper class for speech-to-speech data config YAML."""

    @property
    def vocab_filename(self):
        """fairseq vocabulary file under data root"""
        # Unlike S2T, no default dictionary: S2S targets may be unit-less.
        return self.config.get("vocab_filename", None)

    @property
    def pre_tokenizer(self) -> Dict:
        # No text pre-tokenizer applies to speech targets.
        return None

    @property
    def bpe_tokenizer(self) -> Dict:
        # No subword tokenizer applies to speech targets.
        return None

    @property
    def input_transformed_channels(self):
        """The number of channels in the audio after feature transforms"""
        # TODO: move this into individual transforms
        # TODO: deprecate transforms
        _cur = self.config.get("transforms", {})
        ft_transforms = self.config.get("feature_transforms", {})
        if _cur and ft_transforms:
            _cur.update(ft_transforms)
        else:
            _cur = self.config.get("feature_transforms", {})
        cur = _cur.get("_train", [])

        # `input_channels` is a parent-class @property; delta_deltas stacks
        # the feature with its first and second derivatives (x3 channels).
        _channels = self.input_channels
        if "delta_deltas" in cur:
            _channels *= 3

        return _channels

    @property
    def output_sample_rate(self):
        """The audio sample rate of output target speech"""
        return self.config.get("output_sample_rate", 22050)

    @property
    def target_speaker_embed(self):
        """Target speaker embedding file (one line per target audio sample)"""
        return self.config.get("target_speaker_embed", None)

    @property
    def prepend_tgt_lang_tag_as_bos(self) -> bool:
        """Prepend target lang ID token as the target BOS."""
        return self.config.get("prepend_tgt_lang_tag_as_bos", False)
class MultitaskConfig(object):
    """Wrapper class for multitask data config YAML (one entry per task)."""

    def __init__(self, yaml_path: Path):
        config = get_config_from_yaml(yaml_path)
        # Map task name -> SingleTaskConfig, preserving YAML order.
        self.config = {}
        for k, v in config.items():
            self.config[k] = SingleTaskConfig(k, v)

    def get_all_tasks(self):
        """Return the full task-name -> SingleTaskConfig mapping."""
        return self.config

    def get_single_task(self, name):
        """Return the SingleTaskConfig for `name` (asserts it exists)."""
        assert name in self.config, f"multitask '{name}' does not exist!"
        return self.config[name]

    @property
    def first_pass_decoder_task_index(self):
        """Return the task index of the first-pass text decoder.
        If there are multiple 'is_first_pass_decoder: True' in the config file,
        the last task is used for the first-pass decoder.
        If there is no 'is_first_pass_decoder: True' in the config file,
        the last task whose task_name includes 'target' and decoder_type is not ctc.
        """
        # @property so that `v.is_first_pass_decoder` / `v.decoder_type`
        # attribute-style access stays consistent across config classes.
        idx = -1
        for i, (k, v) in enumerate(self.config.items()):
            if v.is_first_pass_decoder:
                idx = i
        if idx < 0:
            for i, (k, v) in enumerate(self.config.items()):
                if k.startswith("target") and v.decoder_type == "transformer":
                    idx = i
        return idx
class SingleTaskConfig(object):
    """Config wrapper for one task entry of a multitask YAML."""

    def __init__(self, name, config):
        self.task_name = name
        self.config = config
        dict_path = config.get("dict", "")
        # Guard against a missing "dict" key: Path("") resolves to "." which
        # always exists, so the unguarded check would call Dictionary.load("").
        self.tgt_dict = (
            Dictionary.load(dict_path)
            if dict_path and Path(dict_path).exists()
            else None
        )

    @property
    def data(self):
        """Root directory of this task's data ("" if unset)."""
        return self.config.get("data", "")

    @property
    def decoder_type(self):
        """Decoder architecture: "transformer" (default) or "ctc"."""
        return self.config.get("decoder_type", "transformer")

    @property
    def decoder_args(self):
        """Decoder arch related args"""
        args = self.config.get("decoder_args", {})
        return Namespace(**args)

    @property
    def criterion_cfg(self):
        """cfg for the multitask criterion"""
        if self.decoder_type == "ctc":
            from fairseq.criterions.ctc import CtcCriterionConfig

            cfg = CtcCriterionConfig
            # NOTE(review): attributes are set on the config *class*, not an
            # instance, so they are shared across tasks — preserved as-is for
            # compatibility; confirm before instantiating per-task.
            cfg.zero_infinity = self.config.get("zero_infinity", True)
        else:
            from fairseq.criterions.label_smoothed_cross_entropy import (
                LabelSmoothedCrossEntropyCriterionConfig,
            )

            cfg = LabelSmoothedCrossEntropyCriterionConfig
            cfg.label_smoothing = self.config.get("label_smoothing", 0.2)
        return cfg

    @property
    def input_from(self):
        """Condition on encoder/decoder of the main model"""
        return "decoder" if "decoder_layer" in self.config else "encoder"

    @property
    def input_layer(self):
        """0-based index of the main-model layer feeding this task."""
        if self.input_from == "decoder":
            return self.config["decoder_layer"] - 1
        else:
            # default using the output from the last encoder layer (-1)
            return self.config.get("encoder_layer", 0) - 1

    @property
    def loss_weight_schedule(self):
        """"decay" when both loss_weight_max and loss_weight_decay_steps are
        configured, otherwise "fixed"."""
        return (
            "decay"
            if "loss_weight_max" in self.config
            and "loss_weight_decay_steps" in self.config
            else "fixed"
        )

    def get_loss_weight(self, num_updates):
        """Loss weight at `num_updates`: constant for "fixed", linearly decayed
        from loss_weight_max down to loss_weight_min for "decay"."""
        if self.loss_weight_schedule == "fixed":
            weight = self.config.get("loss_weight", 1.0)
        else:  # "decay"
            assert (
                self.config.get("loss_weight_decay_steps", 0) > 0
            ), "loss_weight_decay_steps must be greater than 0 for a decay schedule"
            loss_weight_min = self.config.get("loss_weight_min", 0.0001)
            loss_weight_decay_stepsize = (
                self.config["loss_weight_max"] - loss_weight_min
            ) / self.config["loss_weight_decay_steps"]
            weight = max(
                self.config["loss_weight_max"]
                - loss_weight_decay_stepsize * num_updates,
                loss_weight_min,
            )
        return weight

    @property
    def prepend_bos_and_append_tgt_lang_tag(self) -> bool:
        """Prepend BOS and append target lang ID token to the target (e.g.
        mBART with language token pretraining)."""
        return self.config.get("prepend_bos_and_append_tgt_lang_tag", False)

    @property
    def eos_token(self):
        """EOS token during generation"""
        return self.config.get("eos_token", "<eos>")

    @property
    def rdrop_alpha(self):
        """R-Drop regularization weight (0.0 disables it)."""
        return self.config.get("rdrop_alpha", 0.0)

    @property
    def is_first_pass_decoder(self):
        """Whether this task acts as the first-pass text decoder of a
        multi-decoder model. Raises ValueError for a CTC first pass."""
        flag = self.config.get("is_first_pass_decoder", False)
        if flag:
            if self.decoder_type == "ctc":
                raise ValueError(
                    "First-pass decoder in the multi-decoder model must not be CTC."
                )
            if "target" not in self.task_name:
                # `raise Warning(...)` aborted execution where only a warning
                # was intended; log it instead of raising.
                logger.warning(
                    'The name of the first-pass decoder does not include "target".'
                )
        return flag

    def get_lang_tag_mapping(self):
        """Optional mapping to rewrite language tags."""
        return self.config.get("lang_tag_mapping", {})