| |
| |
| |
| |
| |
| |
|
|
import logging
import os
import sys
from argparse import Namespace
from dataclasses import dataclass, field
from typing import Any, Optional

import numpy as np
import torch
from omegaconf import II, MISSING, OmegaConf, dictconfig

from fairseq.data import (
    AddTargetDataset,
    ConcatDataset,
    Dictionary,
    FileAudioDataset,
    ResamplingDataset,
    encoders,
)
from fairseq.dataclass import FairseqDataclass
from fairseq.dataclass.configs import GenerationConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf

from . import FairseqTask, register_task
from .. import utils
from ..logging import metrics
|
|
|
|
# Module-level logger named after this module's import path.
logger = logging.getLogger(__name__)
|
|
class LabelEncoder(object):
    """Callable that encodes a text label line into token ids using a
    fairseq ``Dictionary``."""

    def __init__(self, dictionary):
        self.dictionary = dictionary

    def __call__(self, label):
        # No EOS is appended, and unknown symbols are not added to the vocab.
        encoded = self.dictionary.encode_line(
            label,
            append_eos=False,
            add_if_not_exist=False,
        )
        return encoded
|
|
class IDEncoder(object):
    """Callable that parses a whitespace-separated string of integer ids
    into a ``torch.IntTensor``."""

    def __call__(self, ids):
        parsed = [int(tok) for tok in ids.split()]
        return torch.IntTensor(parsed)
|
|
|
|
@dataclass
class InferredW2vConfig:
    """Masking / feature-extractor arguments mirrored from the model config,
    used to pre-compute mask indices during data preparation (e.g. for TPU).
    Every field is interpolated (II) from the ``model.*`` config group."""

    # Time-axis masking parameters (mirror the wav2vec 2.0 model options).
    mask_length: Optional[int] = II("model.mask_length")
    mask_prob: Optional[float] = II("model.mask_prob")
    mask_selection: Optional[str] = II("model.mask_selection")
    mask_other: Optional[float] = II("model.mask_other")
    no_mask_overlap: Optional[bool] = II("model.no_mask_overlap")
    mask_min_space: Optional[int] = II("model.mask_min_space")
    # Channel-axis masking parameters.
    mask_channel_length: Optional[int] = II("model.mask_channel_length")
    mask_channel_prob: Optional[float] = II("model.mask_channel_prob")
    mask_channel_selection: Optional[str] = II("model.mask_channel_selection")
    mask_channel_other: Optional[float] = II("model.mask_channel_other")
    no_mask_channel_overlap: Optional[bool] = II("model.no_mask_channel_overlap")
    mask_channel_min_space: Optional[int] = II("model.mask_channel_min_space")

    # Needed to derive the feature-extractor output length when masking.
    conv_feature_layers: Optional[str] = II("model.conv_feature_layers")
    encoder_embed_dim: Optional[int] = II("model.encoder_embed_dim")
|
|
|
|
@dataclass
class AudioPretrainingConfig(FairseqDataclass):
    """Configuration for the ``audio_pretraining`` task (wav2vec-style
    pretraining and fine-tuning)."""

    # Root directory containing the .tsv manifests (and label files).
    data: Optional[str] = field(default=MISSING, metadata={"help": "path to data directory"})
    # NOTE(review): declared but not read anywhere in this file — presumably
    # consumed elsewhere; verify before removing.
    train_path: Optional[str] = field(default=MISSING)

    labels: Optional[str] = field(
        default=None,
        metadata={"help": "extension of the label file to load, used for fine-tuning"},
    )
    # Explicit dictionary file; when unset, <data>/vocab.json is used instead.
    dict_path: Optional[str] = field(default=MISSING)
    # Optional model file passed alongside dict_path to Dictionary().
    dict_model: Optional[str] = field(default=None)
    binarized_dataset: bool = field(
        default=False,
        metadata={
            "help": "if true, loads binarized dataset (useful for very large datasets). "
            "See examples/wav2vec/scripts/binarize_manifest.sh"
        },
    )
    sample_rate: int = field(
        default=16_000,
        metadata={
            "help": "target sample rate. audio files will be up/down sampled to this rate"
        },
    )
    normalize: bool = field(
        default=False,
        metadata={"help": "if set, normalizes input to have 0 mean and unit variance"},
    )
    enable_padding: bool = field(
        default=False, metadata={"help": "pad shorter samples instead of cropping"}
    )
    max_sample_size: Optional[int] = field(
        default=None, metadata={"help": "max sample size to crop to for batching"}
    )
    min_sample_size: Optional[int] = field(
        default=None, metadata={"help": "min sample size to skip small examples"}
    )
    # Smoothing exponent for multilingual sampling (see _get_sample_prob);
    # values < 1 upsample low-resource languages.
    multilang_sampling_alpha: Optional[float] = field(default=0.5)

    eval_wer: bool = field(
        default=False, metadata={"help": "compute WER for Seq2Seq models"}
    )
    eval_wer_config: GenerationConfig = field(
        default_factory=lambda: GenerationConfig(),
        metadata={"help": "beam search config for evaluating wer during training"},
    )
    eval_wer_tokenizer: Any = field(
        default=None,
        metadata={"help": "tokenizer config for evaluating wer during training"},
    )
    eval_wer_post_process: str = field(
        default="letter",
        metadata={
            "help": "remove BPE tokens before scoring (can be sentencepiece, letter, and more)"
        },
    )
    autoregressive: bool = field(
        default=False,
        metadata={
            "help": "required for autoregressive decoders (like seq2seq models); "
            "adds 'prev_output_tokens' to input and appends eos to target"
        },
    )
    num_batch_buckets: int = field(
        default=0,
        metadata={"help": "number of buckets"},
    )
    precompute_mask_indices: bool = field(
        default=False,
        metadata={
            "help": "flag to compute mask indices in data preparation.",
        },
    )

    inferred_w2v_config: Optional[InferredW2vConfig] = field(
        default=None,
        metadata={
            "help": "wav2vec 2.0 masking arguments used to pre-compute masks (required for TPU)",
        },
    )

    # Interpolated from the common config group.
    tpu: bool = II("common.tpu")
    max_tokens: Optional[int] = II("common.max_tokens")
|
|
|
|
@register_task("audio_pretraining", dataclass=AudioPretrainingConfig)
class AudioPretrainingTask(FairseqTask):
    """Task for wav2vec-style audio pretraining and (optionally labeled)
    fine-tuning, with multi-manifest language resampling support."""

    cfg: AudioPretrainingConfig

    def __init__(
        self,
        cfg: AudioPretrainingConfig,
    ):
        super().__init__(cfg)
        if cfg.eval_wer:
            assert cfg.labels is not None, "eval_wer can only be set during fine-tuning"
        self.blank_symbol = "<s>"

        # Build the target dictionary lazily, on first access.
        self.state.add_factory("target_dictionary", self.load_target_dictionary)

    @classmethod
    def setup_task(cls, cfg: AudioPretrainingConfig, **kwargs):
        """Setup the task (e.g., load dictionaries).

        Args:
            cfg (AudioPretrainingConfig): configuration of this task
        """
        return cls(cfg)

    def load_target_dictionary(self):
        """Return the label Dictionary used for fine-tuning, or None when
        no labels are configured (pure pretraining)."""
        if not self.cfg.labels:
            return None
        if self.cfg.dict_path is not None:
            # NOTE(review): dict_path defaults to MISSING, not None, so this
            # branch is taken whenever the field resolves — confirm callers
            # always supply a real path when labels are set.
            return Dictionary(self.cfg.dict_path, self.cfg.dict_model)
        return Dictionary(os.path.join(self.cfg.data, "vocab.json"))

    def _get_mask_precompute_kwargs(self, cfg):
        """Return kwargs for pre-computing mask indices (precompute/TPU path),
        or an empty dict when masks are computed on the fly."""
        if self.cfg.precompute_mask_indices or self.cfg.tpu:
            assert (
                cfg.inferred_w2v_config is not None
            ), "inferred_w2v_config must be set"
            return OmegaConf.to_container(
                cfg.inferred_w2v_config, resolve=True, enum_to_str=True
            )
        else:
            return {}

    def _get_sample_prob(self, dataset_lens):
        """
        Get smoothed sampling probability by language. This helps low-resource
        languages by upsampling them.
        """
        prob = dataset_lens / dataset_lens.sum()
        # BUG FIX: was `self.args.multilang_sampling_alpha`; this task is
        # dataclass-configured and has no `args` attribute.
        smoothed_prob = prob ** self.cfg.multilang_sampling_alpha
        smoothed_prob = smoothed_prob / smoothed_prob.sum()
        return smoothed_prob

    def load_dataset(self, split: str, task_cfg: FairseqDataclass = None, **kwargs):
        """Load one or more audio manifests for `split`.

        `split` may be a comma-separated list of manifest names. With a single
        manifest the dataset is used directly; with several, each sub-dataset
        is up/down-sampled according to `multilang_sampling_alpha` and the
        results are concatenated.
        """
        data_path = self.cfg.data
        task_cfg = task_cfg or self.cfg

        # upgrade old task configs that predate the `autoregressive` flag
        if isinstance(task_cfg, Namespace):
            if not hasattr(task_cfg, "autoregressive"):
                task_cfg.autoregressive = not task_cfg.criterion == "ctc"

        manifest_list = split.split(",")
        datasets = []
        datasets_lengths = []

        for manifest_name in manifest_list:
            manifest = os.path.join(self.cfg.data, "{}.tsv".format(manifest_name))
            dataset = FileAudioDataset(
                manifest,
                sample_rate=task_cfg.sample_rate,
                max_sample_size=self.cfg.max_sample_size,
                min_sample_size=self.cfg.min_sample_size,
                pad=task_cfg.labels is not None or task_cfg.enable_padding,
                normalize=task_cfg.normalize,
                num_buckets=self.cfg.num_batch_buckets or int(self.cfg.tpu),
                compute_mask_indices=(self.cfg.precompute_mask_indices or self.cfg.tpu),
                **self._get_mask_precompute_kwargs(task_cfg),
            )

            if task_cfg.labels:
                # Labels are looked up as <name>.idx, then <name>.<labels>,
                # then <name>.id (integer-id format).
                label_path = os.path.join(data_path, "{}.idx".format(manifest_name))
                if not os.path.exists(label_path):
                    label_path = os.path.join(
                        data_path, "{}.{}".format(manifest_name, task_cfg.labels)
                    )
                if not os.path.exists(label_path):
                    label_path = os.path.join(data_path, "{}.id".format(manifest_name))

                labels = []
                count = 0
                itr = 0
                # BUG FIX (hygiene): the file handle previously shadowed the
                # manifest loop variable `f`.
                with open(label_path, "r") as label_file:
                    for line in label_file:
                        count += 1
                        # drop labels whose audio was skipped by FileAudioDataset
                        if itr < len(dataset.skipped) and count == dataset.skipped[itr]:
                            itr += 1
                            continue
                        labels.append(line.strip())

                if ".id" in label_path:
                    process_label = IDEncoder()
                else:
                    process_label = LabelEncoder(self.target_dictionary)
                assert len(labels) == len(dataset)

                dataset = AddTargetDataset(
                    dataset,
                    labels,
                    pad=self.target_dictionary.pad(),
                    eos=self.target_dictionary.eos(),
                    batch_targets=True,
                    process_label=process_label,
                    add_to_input=getattr(task_cfg, "autoregressive", False),
                )

            # BUG FIX: sub-datasets were never collected, so the
            # multi-manifest branch below operated on empty lists
            # (division by zero in _get_sample_prob, empty ConcatDataset).
            datasets.append(dataset)
            datasets_lengths.append(len(dataset))

        if len(manifest_list) == 1:
            self.datasets[split] = datasets[0]
        else:
            # Manifest names are expected to be "<lang>/<split>"; the leading
            # path component names the language for logging purposes.
            languages = [m.split("/")[0] for m in manifest_list]
            datasets_lengths = np.array(datasets_lengths)
            sample_probs = self._get_sample_prob(datasets_lengths)
            for idx, lang in enumerate(languages):
                logger.info(
                    "Sample probability by language: {} : {:.5f}".format(
                        lang, sample_probs[idx]
                    )
                )
            size_ratio = (sample_probs * datasets_lengths.sum()) / datasets_lengths
            for idx, lang in enumerate(languages):
                logger.info(
                    "Up/Down Sampling ratio by language: {} : {:.2f}".format(
                        lang, size_ratio[idx]
                    )
                )

            # BUG FIX: `epoch` was undefined here (NameError); fairseq passes
            # the current epoch through load_dataset kwargs.
            epoch = kwargs.get("epoch", 1)
            # NOTE(review): AudioPretrainingConfig declares no `seed` field —
            # fall back to 1 when absent; confirm the intended seed source.
            seed = getattr(task_cfg, "seed", 1)
            resampled_lang_datasets = [
                ResamplingDataset(
                    d,
                    size_ratio=size_ratio[i],
                    seed=seed,
                    epoch=epoch,
                    replace=size_ratio[i] >= 1.0,
                )
                for i, d in enumerate(datasets)
            ]
            self.datasets[split] = ConcatDataset(resampled_lang_datasets)

    @property
    def source_dictionary(self):
        # Raw-audio input has no source-side vocabulary.
        return None

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` for the language
        model."""
        return self.state.target_dictionary

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return (sys.maxsize, sys.maxsize)

    def filter_indices_by_size(
        self,
        indices,
        dataset,
        max_positions=None,
        ignore_invalid_inputs=False,
    ):
        # Intentionally a no-op: size filtering is handled by the dataset
        # itself (min/max_sample_size), so all indices are kept.
        return indices

    def valid_step(self, sample, model, criterion):
        """Run a validation step; when WER evaluation is enabled for an
        autoregressive model, also decode and accumulate error counts."""
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        if self.cfg.eval_wer and self.cfg.autoregressive:
            # BUG FIX (hygiene): local previously shadowed the imported
            # `metrics` module.
            wer_metrics = self._inference_with_wer(self.sequence_generator, sample, model)
            logging_output["_num_char_errors"] = wer_metrics["num_char_errors"]
            logging_output["_num_chars"] = wer_metrics["num_chars"]
            logging_output["_num_word_errors"] = wer_metrics["num_word_errors"]
            logging_output["_num_words"] = wer_metrics["num_words"]
        return loss, sample_size, logging_output

    def build_model(self, model_cfg: FairseqDataclass):
        """Build the model and, when WER evaluation is enabled, the beam
        search generator and (optional) tokenizer used during validation."""
        model = super().build_model(model_cfg)

        if self.cfg.eval_wer and self.cfg.autoregressive:
            self.sequence_generator = self.build_generator(
                [model],
                self.cfg.eval_wer_config,
            )
            if self.cfg.eval_wer_tokenizer:
                self.tokenizer = encoders.build_tokenizer(self.cfg.eval_wer_tokenizer)
            else:
                self.tokenizer = None

        return model

    def _inference_with_wer(self, generator, sample, model):
        """Decode `sample` with `generator` and return character/word error
        counts against the reference targets."""
        import editdistance

        def decode(toks):
            s = self.target_dictionary.string(
                toks.int().cpu(),
                self.cfg.eval_wer_post_process,
                escape_unk=True,
            )
            if self.tokenizer:
                s = self.tokenizer.decode(s)
            return s

        num_word_errors, num_char_errors = 0, 0
        num_chars, num_words = 0, 0
        gen_out = self.inference_step(generator, [model], sample, None)
        for i in range(len(gen_out)):
            hyp = decode(gen_out[i][0]["tokens"])
            ref = decode(
                utils.strip_pad(sample["target"][i], self.target_dictionary.pad()),
            )
            num_char_errors += editdistance.eval(hyp, ref)
            num_chars += len(ref)
            hyp_words = hyp.split()
            ref_words = ref.split()
            num_word_errors += editdistance.eval(hyp_words, ref_words)
            num_words += len(ref_words)

        return {
            "num_char_errors": num_char_errors,
            "num_chars": num_chars,
            "num_word_errors": num_word_errors,
            "num_words": num_words,
        }

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate error counts across workers and register derived
        UER/WER metrics."""
        super().reduce_metrics(logging_outputs, criterion)

        # Tensor zero keeps the sums tensor-typed when some logs lack a key.
        zero = torch.scalar_tensor(0.0)
        num_char_errors = sum(
            log.get("_num_char_errors", zero) for log in logging_outputs
        )
        num_chars = sum(log.get("_num_chars", zero) for log in logging_outputs)
        num_word_errors = sum(
            log.get("_num_word_errors", zero) for log in logging_outputs
        )
        num_words = sum(log.get("_num_words", zero) for log in logging_outputs)
        metrics.log_scalar("_num_char_errors", num_char_errors)
        metrics.log_scalar("_num_chars", num_chars)
        metrics.log_scalar("_num_word_errors", num_word_errors)
        metrics.log_scalar("_num_words", num_words)
        if num_chars > 0:
            metrics.log_derived(
                "uer",
                lambda meters: meters["_num_char_errors"].sum
                * 100.0
                / meters["_num_chars"].sum
                if meters["_num_chars"].sum > 0
                else float("nan"),
            )
        if num_words > 0:
            metrics.log_derived(
                "wer",
                lambda meters: meters["_num_word_errors"].sum
                * 100.0
                / meters["_num_words"].sum
                if meters["_num_words"].sum > 0
                else float("nan"),
            )
|
|