diff --git a/fairseq/examples/hubert/tests/sample.xlarge.L30.npy b/fairseq/examples/hubert/tests/sample.xlarge.L30.npy new file mode 100644 index 0000000000000000000000000000000000000000..b617c3876654a69d73e3f9bd2ad7ff01b764cf24 --- /dev/null +++ b/fairseq/examples/hubert/tests/sample.xlarge.L30.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bbe9f0929ecd4c58786be55215d83f85787ff0d81196bc2e73414f82a8939806 +size 3051712 diff --git a/fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md b/fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md new file mode 100644 index 0000000000000000000000000000000000000000..52c528fa1e40e4af290c486bd22d59dbc7aabcea --- /dev/null +++ b/fairseq/examples/textless_nlp/dgslm/hubert_fisher/README.md @@ -0,0 +1,47 @@ +# Dialogue Speech-to-Unit Encoder for dGSLM: The Fisher HuBERT model +For the speech2unit encoder, we train a [HuBERT model](https://arxiv.org/pdf/2106.07447.pdf) on the [Fisher dataset](http://www.lrec-conf.org/proceedings/lrec2004/pdf/767.pdf) for 3 iterations (see [our paper](https://arxiv.org/pdf/2203.16502.pdf) for more details) and train a k-means model with 500 units on the layer 12 features of the HuBERT model. 
+ +## Model checkpoints +The pre-trained HuBERT and k-means model checkpoints can be found here: + +| Fisher HuBERT model | k-means model | +|---------------------|---------------| +|[download](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hubert/hubert_fisher.pt)|[download](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hubert/hubert_fisher_km_500.bin)| + + +## Encode audio to discrete units +Below is an example command to encode a stereo dataset to discrete units using the pre-trained model checkpoints : +```bash +for CHANNEL_ID in 1 2; do + python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \ + --feature_type hubert \ + --kmeans_model_path path/to/hubert_fisher_km_500.bin \ + --acoustic_model_path path/to/hubert_fisher.pt \ + --layer 12 \ + --manifest_path $MANIFEST_FILE \ + --out_quantized_file_path ${OUTPUT_FILE}-channel${CHANNEL_ID} \ + --extension $EXTENSION \ + --channel_id $CHANNEL_ID +done +``` +where MANIFEST_FILE is the output of [wav2vec manifest script](https://github.com/facebookresearch/fairseq/blob/main/examples/wav2vec/wav2vec_manifest.py), which can be obtained through the following command : +``` +python examples/wav2vec/wav2vec_manifest.py --valid-percent=0.0 $AUDIO_DIR --dest=$OUTPUT_DIR --ext=$EXTENSION +``` + +Otherwise, you can encode an audio file in python interactively with the HubertTokenizer class : +```python +# Load the Hubert tokenizer +from examples.textless_nlp.dgslm.dgslm_utils import HubertTokenizer +encoder = HubertTokenizer( + hubert_path = "/path/to/hubert_ckpt.pt", + hubert_layer = 12, + km_path = "path/to/km.bin" +) + +# Encode the audio to units +path = "/path/to/stereo/audio.wav" +codes = encoder.wav2codes(path) +# > ['7 376 376 133 178 486 486 486 486 486 486 486 486 2 486', +# > '7 499 415 177 7 7 7 7 7 7 136 136 289 289 408'] +``` \ No newline at end of file diff --git a/fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md 
b/fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..5d4a59a9acfaaeae616354e25ce031c387857909 --- /dev/null +++ b/fairseq/examples/textless_nlp/dgslm/vocoder_hifigan/README.md @@ -0,0 +1,47 @@ +# Dialogue Unit-to-Speech Decoder for dGSLM +For the unit2speech decoder, we train a [discrete unit-based HiFi-GAN vocoder](https://arxiv.org/pdf/2104.00355.pdf) on the [Fisher dataset](http://www.lrec-conf.org/proceedings/lrec2004/pdf/767.pdf). + +## Model checkpoint +The pre-trained model checkpoint can be found here : + +| HiFi-GAN vocoder based on HuBERT Fisher Units | +|-----------------------------------------------| +|[model checkpoint](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hifigan/hifigan_vocoder) - [config](https://dl.fbaipublicfiles.com/textless_nlp/dgslm/checkpoints/hifigan/config.json) | + +## Decode discrete units to audio +To create waveform from discrete units, use the script `generate_stereo_waveform.py` : +```bash +python examples/textless_nlp/dgslm/vocoder_hifigan/generate_stereo_waveform.py \ + --in-file $INPUT_CODE_FILE \ + --vocoder $VOCODER_PATH \ + --vocoder-cfg $VOCODER_CONFIG \ + --results-path $OUTPUT_DIR +``` +where INPUT_CODE_FILE is expected to have the following format : +``` +{'audio': 'file_1', 'unitA': '8 8 ... 352 352', 'unitB': '217 8 ... 8 8'} +{'audio': 'file_2', 'unitA': '5 5 ... 65 65', 'unitB': '6 35 ... 8 9'} +... 
+``` + +You can also use the HifiganVocoder class to generate waveform from the codes interactively : +```python +# Load the Hifigan vocoder +from examples.textless_nlp.dgslm.dgslm_utils import HifiganVocoder +decoder = HifiganVocoder( + vocoder_path = "/path/to/hifigan_vocoder", + vocoder_cfg_path = "/path/to/config.json", +) + +# Decode the units to waveform +codes = [ + '7 376 376 133 178 486 486 486 486 486 486 486 486 2 486', + '7 499 415 177 7 7 7 7 7 7 136 136 289 289 408', +] +wav = decoder.codes2wav(codes) +# > array of shape (2, 4800) + +# Play the waveform +import IPython.display as ipd +ipd.Audio(wav, rate=16_000) +``` diff --git a/fairseq/examples/textless_nlp/gslm/README.md b/fairseq/examples/textless_nlp/gslm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7a76ffd57c066c20af94aa3fca24c18e2ba4c3dd --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/README.md @@ -0,0 +1,21 @@ +# Generative Spoken Language Modeling + +* [Paper](https://arxiv.org/abs/2102.01192) +* [Demo](https://speechbot.github.io/gslm/index.html) + +We build and evaluate generative speech2speech systems using [Log Mel Filtebank](https://pytorch.org/audio/stable/compliance.kaldi.html#fbank), [Modified CPC](https://github.com/facebookresearch/CPC_audio), [HuBERT Base](https://github.com/pytorch/fairseq/tree/main/examples/hubert) and [Wav2Vec 2.0 Large](https://github.com/pytorch/fairseq/tree/main/examples/wav2vec). Our system is composed of three components, namely, *speech2unit*, *ulm* and *unit2speech*. We explain about models and usage of these components in their respective sub-directories. See the links below. + +## Speech to Unit Model (speech2unit) +Speech to unit model is used for quantizing raw speech into learned discrete speech units. [More details](speech2unit) + +## Unit Language Model (ulm) +Unit Language Model is a generative language model trained on discrete speech units. 
[More details](ulm) + +## Unit to Speech Model (unit2speech) +Unit to speech model is used for synthesizing speech from discrete speech units. [More details](unit2speech) + +## Metrics +We show how to compute ASR based metrics as well as zero-shot metrics proposed in our paper [here](metrics). + +## Tools +We share two tools to resynthesize a given spoken utterance, and generate novel spoken language given a spoken prompt. [More detail](tools) diff --git a/fairseq/examples/textless_nlp/gslm/metrics/README.md b/fairseq/examples/textless_nlp/gslm/metrics/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0a63e2f0d844ce157f9502c82738aac2a0de3f0c --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/metrics/README.md @@ -0,0 +1,10 @@ +# GSLM Metrics + +## ASR Metrics +The suite of metrics here uses an ASR model to transcribe the synthesized speech into text, and then uses text-based metrics. We also use word error rate from ASR transcription itself as one of the metrics. [More details](asr_metrics) + +## ABX Metrics +We use [ABX](https://www.semanticscholar.org/paper/ABX-Discriminability-Measures-and-Applications-Schatz/13d3537228f728c1063cc83743cb118bba3367a0) to evaluate how well-separated phonetic categories are with quantized representations. [More details](abx_metrics) + +## sWUGGY and sBLIMP +We refer to [ZeroSpeech challenge](https://www.zerospeech.com/2021/track_s.html#scoring-based-metrics) for details on the sWUGGY and sBLIMP metrics. diff --git a/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py b/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py new file mode 100644 index 0000000000000000000000000000000000000000..41cf558970608fa5a9241e91e59ba214b609dc73 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/metrics/abx_metrics/dump_abx_feats.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os + +import joblib +import numpy as np + +from examples.textless_nlp.gslm.speech2unit.clustering.utils import get_audio_files +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_features + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + +def get_parser(): + parser = argparse.ArgumentParser( + description="Quantize using K-means clustering over acoustic features." + ) + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + required=True, + help="Acoustic feature type", + ) + parser.add_argument( + "--kmeans_model_path", + type=str, + required=True, + help="K-means model file path to use for inference", + ) + parser.add_argument( + "--manifest_path", + type=str, + default=None, + help="Manifest file containing the root dir and file names", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Pretrained model checkpoint", + ) + parser.add_argument( + "--layer", + type=int, + help="The layer of the pretrained model to extract features from", + default=-1, + ) + parser.add_argument( + "--out_dir_path", + required=True, + type=str, + help="File path of quantized output.", + ) + parser.add_argument( + "--extension", type=str, default=".flac", help="Features file path" + ) + return parser + + +def one_hot(feat, n_clusters): + return np.eye(n_clusters)[feat] + +def main(args, logger): + # Feature extraction + logger.info(f"Extracting {args.feature_type} acoustic features...") + features_batch = get_features( + feature_type=args.feature_type, + checkpoint_path=args.checkpoint_path, + layer=args.layer, + manifest_path=args.manifest_path, + sample_pct=1.0, + 
flatten=False, + ) + logger.info(f"Features extracted for {len(features_batch)} utterances.\n") + logger.info(f"Dimensionality of representation = {features_batch[0].shape[1]}") + + logger.info(f"Loading K-means model from {args.kmeans_model_path} ...") + kmeans_model = joblib.load(open(args.kmeans_model_path, "rb")) + kmeans_model.verbose = False + + _, fnames, _ = get_audio_files(args.manifest_path) + + os.makedirs(args.out_dir_path, exist_ok=True) + logger.info(f"Writing quantized features to {args.out_dir_path}") + for i, feats in enumerate(features_batch): + pred = kmeans_model.predict(feats) + emb = one_hot(pred, kmeans_model.n_clusters) + base_fname = os.path.basename(fnames[i]).rstrip(args.extension) + output_path = os.path.join(args.out_dir_path, f"{base_fname}.npy") + with open(output_path, "wb") as f: + np.save(f, emb) + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md new file mode 100644 index 0000000000000000000000000000000000000000..90741f42b0b070f2a91b63c8badb817c6aa24230 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/README.md @@ -0,0 +1,87 @@ +# ASR-based evaluation + +Overall, the life cycle of the ASR-based evaluation for an ULM contains the following steps: + 1. Training an ULM and sampling from it [[description]](./../../ulm) + 2. Running UTS on the sampled unit sequences [[description]](./../../unit2speech) + 3. Pre-processing for the ASR (down-sampling to 16 KHz, aligning length of the generated audio with ground-truth utterances) + 4. Running ASR + 5. Calculation of the post-ASR evaluation metrics + +Here we assume that you have already went throught the first two steps and focus on the rest. 
+ +## Preprocessing +### Down-sampling to 16KHz +The bulk conversion can be done by running +```bash + python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py $UTS_OUTPUT $UTS_OUTPUT_DOWNSAMPLE + ``` + where `$UTS_OUTPUT` specifies the directory with the generated audio and `$UTS_OUTPUT_DOWNSAMPLE` is the directory where downsampled audio would be saved. + + ### Matching by length +This step is somewhat optional. However, if you want to compare the fluency and diversity of a generated speech utterance to that of the ground-truth speech with the same prefix, it is a good idea to force them to be of the same length. +```bash +python $FAIRSEQ_ROOT/examples/textless_nlp/asr_metrics/cut_as.py \ + --samples_dir=$UTS_OUTPUT_DOWNSAMPLE --out_dir=$UTS_OUTPUT_DOWNSAMPLE_CUT \ + --prompts_description=data/ground_truth_continuation_dev.json +``` + +Here `ground_truth_continuation_dev.json` is a json file with ground-truth text from LibriSpeech dev-clean, associated with some meta-data (assuming the evaluation is done on dev-clean). This file can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_dev.json). A similar file for the test-clean is [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/ground_truth_continuation_test.json). These files are used for the evaluation and contain texts for audio sequences that are at least 6s long. + +## Running ASR +We use a pre-trained wav2vec model to run the ASR step. We firstly need to prepare manifest files which, roughly, tell the ASR system which files we want to transcribe. You can find more details and download the `960h_scratch.pt` checkpoint +[[here]](https://github.com/pytorch/fairseq/blob/main/examples/wav2vec/README.md)). To run ASR, you would also need to +install KenLM, Flashlight decoder, and download the KenLM 4-gram English language model. 
+ +```bash + python $FAIRSEQ_ROOT/examples/wav2vec/wav2vec_manifest.py \ + $UTS_OUTPUT_DOWNSAMPLE_CUT --valid-percent 0.0 --dest $MANIFEST_DIR --ext wav +``` +where `$UTS_OUTPUT_DOWNSAMPLE_CUT` speficies the directory with the preprocessed UTS outputs and `$MANIFEST_DIR` is the output directory. + +We will be running an out-of-the-box evaluation script which requires ground-truth transcripts to measure quality metrics. We are only +interested in the transcripts (and we don't have ground-truth outputs for when our ULM generated!), hence we will just generate +some dummy transcripts instead: +```bash +cp $FAIRSEQ_ROOT/examples/textless_nlp/gslm/asr_metrics/misc/dict.ltr.txt $MANIFEST_DIR +python $FAIRSEQ_ROOT/examples/textless_nlp/gslm/asr_metrics/misc/dummy_asr_data.py --tsv=$MANIFEST_DIR/train.tsv \ + --output-dir=$MANIFEST_DIR +``` + +Now we are ready for running ASR: +``` +mkdir -p asr +python $FAIRSEQ_ROOT/examples/speech_recognition/infer.py \ + $MANIFEST_DIR \ + --task audio_pretraining --nbest 1 --path 960h_scratch.pt \ + --gen-subset=train --results-path $PATH_TO_ASR_OUTPUT \ + --w2l-decoder kenlm --lm-model 4-gram.bin \ + --lexicon librispeech/lexicon_ltr.lst --word-score -1 \ + --sil-weight 0 --lm-weight 2 --criterion ctc --labels ltr --max-tokens 300000 --remove-bpe letter +``` +where `lexicon_ltr.lst` is the LibriSpeech lexicon and `$PATH_TO_ASR_OUTPUT` is the output directory (can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/lexicon_ltr.lst)). + +## Evaluation metrics +We run evaluation on the 1_000 shortest sequences that are at least 6s long. To filter those from the ASR transcript, we additionally provide each metric script with the paths to the manifest and `ground_truth_continuation_*` files. 
+ +### Perplexity (PPX) +To get a PPX metric estimate on an ASR transcript, you need to run the following command: +```bash +python ppx.py $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail\ + --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json +``` +where `--cut-tail` tells the script to ignore the last token on each line (ASR puts the sequence ID there). + +### Self- and Auto-BLEU +```bash +python self_bleu.py $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt --cut-tail \ + --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json +``` + +### Continuation-BLEU +```bash +python continuation_eval.py --asr-transcript $PATH_TO_ASR_OUTPUT/hypo.word-960h_scratch.pt-train.txt \ + --manifest=$MANIFEST_DIR/train.tsv --prompts-description=data/ground_truth_continuation_dev.json +``` + +### AUC +Based on the metrics calculated above, we can estimate the AUC of the perplexity/diversity trade-off. We provide an illustration in a [Colab notebook](https://colab.research.google.com/drive/1pVPfOVax_PU3MkYdHRSsa-SI8GBUldNt?usp=sharing). diff --git a/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py new file mode 100644 index 0000000000000000000000000000000000000000..062bb82f669f63a537b6ee8df4d42d292eb2575e --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/metrics/asr_metrics/self_auto_bleu.py @@ -0,0 +1,201 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import numpy as np +import nltk +from misc.bleu_utils import sentence_bleu +import warnings + + +def get_target_sequences(manifest, ground_truth, to_take=1000): + import json + import pathlib + + with open(ground_truth, 'r') as fin: + original_continuations = json.loads(fin.read()) + + sequence2length = [(k, v[0]) for k, v in original_continuations.items()] + assert all(float(v) >= 6.0 for (_, v) in sequence2length) # 6 seconds + + sequence2length.sort(key=lambda x: x[1]) + to_take_sequences = set(v[0] for v in sequence2length[:to_take]) + to_take_ids = [] + + with open(manifest, 'r') as f: + f.readline() + + for i, line in enumerate(f.readlines()): + seq_id = line.split()[0] + seq_id = pathlib.Path(seq_id).name.split('__')[0] + + if seq_id in to_take_sequences: + to_take_ids.append(i) + + print(f'Took {len(to_take_ids)} ids') + return set(to_take_ids) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--asr-transcript', type=str, + help='Path to the transcript file.') + + parser.add_argument('--manifest', required=True) + parser.add_argument('--prompts-description', required=True) + + parser.add_argument('--cut-id', action='store_true', + help='Whether cut the first token (typically a seq id)') + parser.add_argument('--cut-tail', action='store_true', + help='Whether cut the last token (typically a speaker id)') + parser.add_argument('--debug', action='store_true') + + args = parser.parse_args() + + return args + + +def get_self_bleu(utterances, averaging_mode, weights): + self_bleu = [] + + for i in range(len(utterances)): + hypo = utterances[i] + rest = utterances[:i] + utterances[i+1:] + + self_bleu.append(sentence_bleu(rest, hypo, weights, + no_length_penalty=True, averaging_mode=averaging_mode)) + + return self_bleu + + +def get_self_bleu2_arithmetic(utterances): + weights = (0.5, 0.5) # equal weight for unigrams and bigrams + return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights) + + 
+def get_self_bleu2_geometric(utterances): + weights = (0.5, 0.5) + return get_self_bleu(utterances, averaging_mode='geometric', weights=weights) + + +def get_auto_bleu2_arithmetic(utterances): + weights = (0.5, 0.5) + return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances] + + +def get_auto_bleu2_geometric(utterances): + weights = (0.5, 0.5) + return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances] + + +def get_auto_bleu3_geometric(utterances): + weights = (1./3, 1./3, 1./3) + return [auto_bleu(u, mean_mode='geometric', weights=weights) for u in utterances] + + +def get_auto_bleu3_arithmetic(utterances): + weights = (1./3, 1./3, 1./3) + return [auto_bleu(u, mean_mode='arithmetic', weights=weights) for u in utterances] + + +def get_self_bleu3_arithmetic(utterances): + weights = (1./3, 1./3, 1./3) + return get_self_bleu(utterances, averaging_mode='arithmetic', weights=weights) + + +def get_self_bleu3_geometric(utterances): + weights = (1./3, 1./3, 1./3) + return get_self_bleu(utterances, averaging_mode='geometric', weights=weights) + + +def auto_bleu(sentence, weights, mean_mode='arithmetic'): + if len(sentence) <= 1: + return 0 + + N = len(weights) + + bleu_n = np.zeros([N]) + for n in range(N): + targ_ngrams = list(nltk.ngrams(sentence, n+1)) + for p in range(len(targ_ngrams)): + left = sentence[:p] + right = sentence[(p+n+1):] + rest_ngrams = list(nltk.ngrams(left, n+1)) + \ + list(nltk.ngrams(right, n+1)) + # compute the nb of matching ngrams + bleu_n[n] += targ_ngrams[p] in rest_ngrams + bleu_n[n] /= len(targ_ngrams) # average them to get a proportion + + weights = np.array(weights) + if mean_mode == 'arithmetic': + return (bleu_n * weights).sum() + elif mean_mode == 'geometric': + return (bleu_n ** weights).prod() + else: + raise ValueError(f'Unknown agggregation mode {mean_mode}') + + +def main(): + from multiprocessing import Pool + + args = get_args() + target_ids = get_target_sequences(args.manifest, 
args.prompts_description) + + with open(args.asr_transcript, 'r') as fin: + lines = fin.readlines() + + terms = [x.strip().split() for x in lines] + filtered = [] + for term in terms: + line_id = int(term[-1].split('-')[1][:-1]) + if line_id in target_ids: + filtered.append(term) + terms = filtered + + if args.cut_id: + terms = [x[1:] for x in terms] + if args.cut_tail: + terms = [x[:-1] for x in terms] + + if args.debug: + terms = terms[:10] + + tasks = [ + ('Self-BLEU2-arithmetic', get_self_bleu2_arithmetic), + ('Self-BLEU2-geometric', get_self_bleu2_geometric), + ('Auto-BLEU2-arithmetic', get_auto_bleu2_arithmetic), + ('Auto-BLEU2-geometric', get_auto_bleu2_geometric), + + ('Self-BLEU3-arithmetic', get_self_bleu3_arithmetic), + ('Self-BLEU3-geometric', get_self_bleu3_geometric), + ('Auto-BLEU3-arithmetic', get_auto_bleu3_arithmetic), + ('Auto-BLEU3-geometric', get_auto_bleu3_geometric), + ] + + n_processes = min(16, len(tasks)) + with Pool(n_processes) as pool: + metrics = pool.map(run_f, [(t[1], terms) for t in tasks]) + + for (metric_name, _), metric in zip(tasks, metrics): + metric, sem = np.mean(metric), np.std(metric) / np.sqrt(len(metric)) + + metric, sem = [ + round(100 * x, 2) for x in [metric, sem] + ] + + print(f'{metric_name} {metric} +- {sem}') + + +def run_f(task_params): + f, terms = task_params + return f(terms) + + +if __name__ == '__main__': + # NLTK produces warnings + warnings.filterwarnings("ignore") + + main() diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/README.md b/fairseq/examples/textless_nlp/gslm/speech2unit/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9dff9d33ac6cc1a7e2c6c7c653d1265123174abd --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/README.md @@ -0,0 +1,68 @@ +# Speech to Unit Model (speech2unit) + +## Acoustic Model +For quantizing speech we learn a K-means clustering over acoustic representations for which we either use Log-Mel Filterbank or pretrained acoustic 
representation models. For using pretrained models, please download from their respective locations linked below. +* [Modified CPC](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/cpc_big_ll6kh_top_ctc.pt) +* [HuBERT-Base](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) +* [Wav2Vec 2.0-Base](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt) + +## Quantization Model +You can download pretrained quantized model from the list below. + +K-Means Model | Download Link +|-|- +Log Mel Filterbank + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km50/km.bin) +Log Mel Filterbank + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km100/km.bin) +Log Mel Filterbank + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/km200/km.bin) +Modified CPC + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km50/km.bin) +Modified CPC + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km100/km.bin) +Modified CPC + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/km200/km.bin) +HuBERT Base + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km50/km.bin) +HuBERT Base + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km100/km.bin) +HuBERT Base + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km200/km.bin) +wav2vec 2.0 Large + KM50 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km50/km.bin) +wav2vec 2.0 Large + KM100 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km100/km.bin) +wav2vec 2.0 Large + KM200 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/km200/km.bin) + +### Quantization +For quantizing speech with a given acoustic representation, please follow the steps below. +1. 
Learn K-means clustering model +``` +N_CLUSTERS= +TYPE= +CKPT_PATH= +LAYER= +MANIFEST= +KM_MODEL_PATH= + +PYTHONPATH=. python examples/textless_nlp/gslm/speech2unit/clustering/cluster_kmeans.py \ + --num_clusters $N_CLUSTERS \ + --feature_type $TYPE \ + --checkpoint_path $CKPT_PATH \ + --layer $LAYER \ + --manifest_path $MANIFEST \ + --out_kmeans_model_path $KM_MODEL_PATH +``` +2. Quantize using the learned clusters +``` +MANIFEST= +OUT_QUANTIZED_FILE= + +python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \ + --feature_type $TYPE \ + --kmeans_model_path $KM_MODEL_PATH \ + --acoustic_model_path $CKPT_PATH \ + --layer $LAYER \ + --manifest_path $MANIFEST \ + --out_quantized_file_path $OUT_QUANTIZED_FILE \ + --extension ".flac" +``` + +Note about the manifest file is a file with paths and length of input audio files. The format of the file is as follows: +``` + +\t +\t +... +``` + diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py new file mode 100644 index 0000000000000000000000000000000000000000..031567c6d85d16b5236053abf008b7cabccb4673 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/dump_feats.py @@ -0,0 +1,91 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging + +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import ( + get_and_dump_features, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Compute and dump log mel fbank features." 
+ ) + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + help="Acoustic feature type", + ) + parser.add_argument( + "--manifest_path", + type=str, + default=None, + help="Manifest file containing the root dir and file names", + ) + parser.add_argument( + "--out_features_path", + type=str, + default=None, + help="Features file path to write to", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Pretrained acoustic model checkpoint", + ) + parser.add_argument( + "--layer", + type=int, + help="The layer of the pretrained model to extract features from", + default=-1, + ) + parser.add_argument( + "--sample_pct", + type=float, + help="Percent data to use for K-means training", + default=0.1, + ) + parser.add_argument( + "--out_features_path", + type=str, + help="Path to save log mel fbank features", + ) + return parser + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +if __name__ == "__main__": + """ + Example command: + python ~/speechbot/clustering/dump_logmelfank_feats.py \ + --manifest_path /checkpoint/kushall/data/LJSpeech-1.1/asr_input_wavs_16k/train.tsv + --out_features_path /checkpoint/kushall/experiments/speechbot/logmelfbank/features/ljspeech/train.npy + """ + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + + logger.info(f"Extracting {args.feature_type} acoustic features...") + get_and_dump_features( + feature_type=args.feature_type, + checkpoint_path=args.checkpoint_path, + layer=args.layer, + manifest_path=args.manifest_path, + sample_pct=args.sample_pct, + flatten=True, + out_features_path=args.out_features_path, + ) + logger.info(f"Saved extracted features at {args.out_features_path}") diff --git 
a/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..dd95105232435c6cf89a8b2a571e7eadd03a98ca --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py @@ -0,0 +1,141 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os + +import numpy as np + +import joblib +from examples.textless_nlp.gslm.speech2unit.clustering.utils import ( + get_audio_files, +) +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import ( + get_features, +) + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Quantize using K-means clustering over acoustic features." + ) + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + required=True, + help="Acoustic feature type", + ) + parser.add_argument( + "--acoustic_model_path", + type=str, + help="Pretrained acoustic model checkpoint" + ) + parser.add_argument( + "--layer", + type=int, + help="The layer of the pretrained model to extract features from", + default=-1, + ) + parser.add_argument( + "--kmeans_model_path", + type=str, + required=True, + help="K-means model file path to use for inference", + ) + parser.add_argument( + "--features_path", + type=str, + default=None, + help="Features file path. 
You don't need to enter acoustic model details if you have dumped features", + ) + parser.add_argument( + "--manifest_path", + type=str, + default=None, + help="Manifest file containing the root dir and file names", + ) + parser.add_argument( + "--out_quantized_file_path", + required=True, + type=str, + help="File path of quantized output.", + ) + parser.add_argument( + "--extension", type=str, default=".flac", help="Features file path" + ) + parser.add_argument( + "--channel_id", + choices=['1', '2'], + help="The audio channel to extract the units in case of stereo file.", + default=None, + ) + parser.add_argument( + "--hide-fname", action='store_true', + help="Hide file names in the output file." + ) + return parser + + +def main(args, logger): + # Feature extraction + if args.features_path is not None: + logger.info(f"Loading acoustic features from {args.features_path}...") + features_batch = np.load(args.features_path) + else: + logger.info(f"Extracting {args.feature_type} acoustic features...") + features_batch = get_features( + feature_type=args.feature_type, + checkpoint_path=args.acoustic_model_path, + layer=args.layer, + manifest_path=args.manifest_path, + sample_pct=1.0, + flatten=False, + channel_id=int(args.channel_id) if args.channel_id else None, + ) + logger.info( + f"Features extracted for {len(features_batch)} utterances.\n" + ) + logger.info( + f"Dimensionality of representation = {features_batch[0].shape[1]}" + ) + + # K-means model + logger.info(f"Loading K-means model from {args.kmeans_model_path} ...") + kmeans_model = joblib.load(open(args.kmeans_model_path, "rb")) + kmeans_model.verbose = False + + _, fnames, _ = get_audio_files(args.manifest_path) + + os.makedirs(os.path.dirname(args.out_quantized_file_path), exist_ok=True) + print(f"Writing quantized predictions to {args.out_quantized_file_path}") + with open(args.out_quantized_file_path, "w") as fout: + for i, feats in enumerate(features_batch): + pred = kmeans_model.predict(feats) + 
pred_str = " ".join(str(p) for p in pred) + base_fname = os.path.basename(fnames[i]).rstrip('.'+args.extension.lstrip('.')) + if args.channel_id is not None: + base_fname = base_fname+f'-channel{args.channel_id}' + if not args.hide_fname: + fout.write(f"{base_fname}|{pred_str}\n") + else: + fout.write(f"{pred_str}\n") + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cf08d1fe4b470477b724aa8d770d91c0cac35a0e --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/clustering/utils.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Tuple + + +def get_audio_files(manifest_path: str) -> Tuple[str, List[str], List[int]]: + fnames, sizes = [], [] + with open(manifest_path, "r") as f: + root_dir = f.readline().strip() + for line in f: + items = line.strip().split("\t") + assert ( + len(items) == 2 + ), f"File must have two columns separated by tab. 
Got {line}" + fnames.append(items[0]) + sizes.append(int(items[1])) + return root_dir, fnames, sizes diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..2ea3890c2896bce29a9f0c9c81579b45d69b07dc --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/cpc_feature_reader.py @@ -0,0 +1,204 @@ +import soundfile as sf +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class CpcFeatureReader: + """ + Wrapper class to run inference on CPC model. + Helps extract features for a given audio file. + """ + + def __init__( + self, + checkpoint_path, + layer, + use_encoder_layer=False, + norm_features=False, + sample_rate=16000, + max_chunk=64000, + use_cuda=True, + ): + self.model = load_cpc_model(checkpoint_path, layer).eval() + self.sample_rate = sample_rate + self.max_chunk = max_chunk + self.norm_features = norm_features + self.use_encoder_layer = use_encoder_layer + self.use_cuda = use_cuda + if self.use_cuda: + self.model.cuda() + + def read_audio(self, path, ref_len=None, channel_id=None): + wav, sr = sf.read(path) + if channel_id is not None: + assert wav.ndim == 2, \ + f"Expected stereo input when channel_id is given ({path})" + assert channel_id in [1, 2], \ + "channel_id is expected to be in [1, 2]" + wav = wav[:, channel_id-1] + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + assert sr == self.sample_rate, sr + if ref_len is not None and abs(ref_len - len(wav)) > 160: + print(f"ref {ref_len} != read {len(wav)} ({path})") + return wav + + def get_feats(self, file_path, ref_len=None, channel_id=None): + x = self.read_audio(file_path, ref_len, channel_id) + # Inspired from CPC_audio feature_loader.py + with torch.no_grad(): + x = torch.from_numpy(x).float() + if self.use_cuda: + x = x.cuda() + x = x.view(1, 1, 
-1) + size = x.size(2) + feat = [] + start = 0 + while start < size: + if start + self.max_chunk > size: + break + x_chunk = x[..., start : start + self.max_chunk] + feat_chunk = self.model.extract_features( + source=x_chunk, + get_encoded=self.use_encoder_layer, + norm_output=self.norm_features, + ) + feat.append(feat_chunk) + start += self.max_chunk + + if start < size: + x_chunk = x[:, -self.max_chunk :] + feat_chunk = self.model.extract_features( + source=x_chunk, + get_encoded=self.use_encoder_layer, + norm_output=self.norm_features, + ) + df = x_chunk.size(2) // feat_chunk.size(1) + delta = (size - start) // df + feat.append(feat_chunk[:, -delta:]) + return torch.cat(feat, 1).squeeze(0) + + +def load_cpc_model(checkpoint_path, layer=None): + state_dict = torch.load(checkpoint_path) + weights = state_dict["weights"] + config = state_dict["config"] + if layer is not None: + config["nLevelsGRU"] = layer + + encoder = CPCEncoder(config["hiddenEncoder"]) + ar_net = CPCAR( + config["hiddenEncoder"], config["hiddenGar"], False, config["nLevelsGRU"] + ) + + model = CPCModel(encoder, ar_net) + model.load_state_dict(weights, strict=False) + model.config = config + + return model + + +class ChannelNorm(nn.Module): + def __init__(self, num_features, epsilon=1e-05, affine=True): + super(ChannelNorm, self).__init__() + if affine: + self.weight = nn.parameter.Parameter(torch.Tensor(1, num_features, 1)) + self.bias = nn.parameter.Parameter(torch.Tensor(1, num_features, 1)) + else: + self.weight = None + self.bias = None + self.epsilon = epsilon + self.p = 0 + self.affine = affine + self.reset_parameters() + + def reset_parameters(self): + if self.affine: + torch.nn.init.ones_(self.weight) + torch.nn.init.zeros_(self.bias) + + def forward(self, x): + cum_mean = x.mean(dim=1, keepdim=True) + cum_var = x.var(dim=1, keepdim=True) + x = (x - cum_mean) * torch.rsqrt(cum_var + self.epsilon) + if self.weight is not None: + x = x * self.weight + self.bias + return x + + +class 
CPCEncoder(nn.Module): + def __init__(self, hidden_dim=512): + super(CPCEncoder, self).__init__() + self.conv0 = nn.Conv1d(1, hidden_dim, 10, stride=5, padding=3) + self.batchNorm0 = ChannelNorm(hidden_dim) + self.conv1 = nn.Conv1d(hidden_dim, hidden_dim, 8, stride=4, padding=2) + self.batchNorm1 = ChannelNorm(hidden_dim) + self.conv2 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1) + self.batchNorm2 = ChannelNorm(hidden_dim) + self.conv3 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1) + self.batchNorm3 = ChannelNorm(hidden_dim) + self.conv4 = nn.Conv1d(hidden_dim, hidden_dim, 4, stride=2, padding=1) + self.batchNorm4 = ChannelNorm(hidden_dim) + self.DOWNSAMPLING = 160 + + def get_output_dim(self): + return self.conv4.out_channels + + def forward(self, x): + x = F.relu(self.batchNorm0(self.conv0(x))) + x = F.relu(self.batchNorm1(self.conv1(x))) + x = F.relu(self.batchNorm2(self.conv2(x))) + x = F.relu(self.batchNorm3(self.conv3(x))) + x = F.relu(self.batchNorm4(self.conv4(x))) + return x + + +class CPCAR(nn.Module): + def __init__(self, dim_encoded, dim_output, keep_hidden, num_layers): + super(CPCAR, self).__init__() + self.baseNet = nn.LSTM( + dim_encoded, dim_output, num_layers=num_layers, batch_first=True + ) + self.hidden = None + self.keep_hidden = keep_hidden + + def get_output_dim(self): + return self.baseNet.hidden_size + + def forward(self, x): + try: + self.baseNet.flatten_parameters() + except RuntimeError: + pass + x, h = self.baseNet(x, self.hidden) + if self.keep_hidden: + if isinstance(h, tuple): + self.hidden = tuple(x.detach() for x in h) + else: + self.hidden = h.detach() + return x + + +class CPCModel(nn.Module): + def __init__(self, encoder, ar_net): + super(CPCModel, self).__init__() + self.gEncoder = encoder + self.gAR = ar_net + self.config = None + + def forward(self, x, label): + encoded = self.gEncoder(x).permute(0, 2, 1) + cpc_feature = self.gAR(encoded) + return cpc_feature, encoded, label + + def 
extract_features(self, source, get_encoded=False, norm_output=False): + cpc_feature, encoded, _ = self.forward(source, None) + if get_encoded: + cpc_feature = encoded + if norm_output: + mean = cpc_feature.mean(dim=1, keepdim=True) + var = cpc_feature.var(dim=1, keepdim=True) + cpc_feature = (cpc_feature - mean) / torch.sqrt(var + 1e-08) + return cpc_feature diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..4fef859fb3948fd015b90b8e82c2c8a93d3685af --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/hubert_feature_reader.py @@ -0,0 +1,70 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import fairseq +import soundfile as sf +import torch.nn.functional as F + + +class HubertFeatureReader: + """ + Wrapper class to run inference on HuBERT model. + Helps extract features for a given audio file. 
+ """ + + def __init__(self, checkpoint_path, layer, max_chunk=1600000, use_cuda=True): + ( + model, + cfg, + task, + ) = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path] + ) + self.model = model[0].eval() + self.task = task + self.layer = layer + self.max_chunk = max_chunk + self.use_cuda = use_cuda + if self.use_cuda: + self.model.cuda() + + def read_audio(self, path, ref_len=None, channel_id=None): + wav, sr = sf.read(path) + if channel_id is not None: + assert wav.ndim == 2, \ + f"Expected stereo input when channel_id is given ({path})" + assert channel_id in [1, 2], \ + "channel_id is expected to be in [1, 2]" + wav = wav[:, channel_id-1] + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + assert sr == self.task.cfg.sample_rate, sr + if ref_len is not None and abs(ref_len - len(wav)) > 160: + print(f"ref {ref_len} != read {len(wav)} ({path})") + return wav + + def get_feats(self, file_path, ref_len=None, channel_id=None): + x = self.read_audio(file_path, ref_len, channel_id) + with torch.no_grad(): + x = torch.from_numpy(x).float() + if self.use_cuda: + x = x.cuda() + if self.task.cfg.normalize: + x = F.layer_norm(x, x.shape) + x = x.view(1, -1) + + feat = [] + for start in range(0, x.size(1), self.max_chunk): + x_chunk = x[:, start: start + self.max_chunk] + feat_chunk, _ = self.model.extract_features( + source=x_chunk, + padding_mask=None, + mask=False, + output_layer=self.layer, + ) + feat.append(feat_chunk) + return torch.cat(feat, 1).squeeze(0) diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..5879da7067479df21f7cc3536f3f32a3ba5a1986 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/logmel_feature_reader.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import soundfile as sf +import torch +import torchaudio.compliance.kaldi as kaldi + + +class LogMelFeatureReader: + """ + Wrapper class to run inference on HuBERT model. + Helps extract features for a given audio file. + """ + + def __init__(self, *args, **kwargs): + self.num_mel_bins = kwargs.get("num_mel_bins", 80) + self.frame_length = kwargs.get("frame_length", 25.0) + + def get_feats(self, file_path, channel_id=None): + wav, sr = sf.read(file_path) + if channel_id is not None: + assert wav.ndim == 2, \ + f"Expected stereo input when channel_id is given ({file_path})" + wav = wav[:, channel_id-1] + feats = torch.from_numpy(wav).float() + feats = kaldi.fbank( + feats.unsqueeze(0), + num_mel_bins=self.num_mel_bins, + frame_length=self.frame_length, + sample_frequency=sr, + ) + return feats diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2eca68e800e442d89af278a624354fc045e1f73a --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/utils.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import gc +import os +import random +import shutil +import numpy as np + +import torch +import tqdm +from examples.textless_nlp.gslm.speech2unit.pretrained.cpc_feature_reader import ( + CpcFeatureReader, +) +from examples.textless_nlp.gslm.speech2unit.pretrained.hubert_feature_reader import ( + HubertFeatureReader, +) +from examples.textless_nlp.gslm.speech2unit.pretrained.logmel_feature_reader import ( + LogMelFeatureReader, +) +from examples.textless_nlp.gslm.speech2unit.pretrained.w2v2_feature_reader import ( + Wav2VecFeatureReader, +) + + +def get_feature_reader(feature_type): + if feature_type == "logmel": + return LogMelFeatureReader + elif feature_type == "hubert": + return HubertFeatureReader + elif feature_type == "w2v2": + return Wav2VecFeatureReader + elif feature_type == "cpc": + return CpcFeatureReader + else: + raise NotImplementedError(f"{feature_type} is not supported.") + + +def get_feature_iterator( + feature_type, checkpoint_path, layer, manifest_path, sample_pct, channel_id +): + feature_reader_cls = get_feature_reader(feature_type) + with open(manifest_path, "r") as fp: + lines = fp.read().split("\n") + root = lines.pop(0).strip() + file_path_list = [ + os.path.join(root, line.split("\t")[0]) + for line in lines + if len(line) > 0 + ] + if sample_pct < 1.0: + file_path_list = random.sample( + file_path_list, int(sample_pct * len(file_path_list)) + ) + num_files = len(file_path_list) + reader = feature_reader_cls( + checkpoint_path=checkpoint_path, layer=layer + ) + + def iterate(): + for file_path in file_path_list: + feats = reader.get_feats(file_path, channel_id=channel_id) + yield feats.cpu().numpy() + + return iterate, num_files + + +def get_features( + feature_type, checkpoint_path, layer, manifest_path, sample_pct, flatten, channel_id +): + generator, num_files = get_feature_iterator( + feature_type=feature_type, + checkpoint_path=checkpoint_path, + layer=layer, + manifest_path=manifest_path, + sample_pct=sample_pct, + 
channel_id=channel_id + ) + iterator = generator() + + features_list = [] + for features in tqdm.tqdm(iterator, total=num_files): + features_list.append(features) + + # Explicit clean up + del iterator + del generator + gc.collect() + torch.cuda.empty_cache() + + if flatten: + return np.concatenate(features_list) + + return features_list + + +def get_and_dump_features( + feature_type, + checkpoint_path, + layer, + manifest_path, + sample_pct, + flatten, + out_features_path, +): + # Feature extraction + features_batch = get_features( + feature_type=feature_type, + checkpoint_path=checkpoint_path, + layer=layer, + manifest_path=manifest_path, + sample_pct=sample_pct, + flatten=flatten, + ) + + # Save features + out_dir_path = os.path.dirname(out_features_path) + os.makedirs(out_dir_path, exist_ok=True) + shutil.copyfile( + manifest_path, + os.path.join(out_dir_path, os.path.basename(manifest_path)), + ) + np.save(out_features_path, features_batch) + + return features_batch diff --git a/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..9f9da6c4993c623fcfd48c6ef5a698b28074eaca --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/speech2unit/pretrained/w2v2_feature_reader.py @@ -0,0 +1,56 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import fairseq +import soundfile as sf + + +class Wav2VecFeatureReader: + """ + Wrapper class to run inference on Wav2Vec 2.0 model. + Helps extract features for a given audio file. 
+ """ + + def __init__(self, checkpoint_path, layer, use_cuda=True): + state = fairseq.checkpoint_utils.load_checkpoint_to_cpu( + checkpoint_path + ) + + w2v_args = state["args"] + self.task = fairseq.tasks.setup_task(w2v_args) + model = self.task.build_model(w2v_args) + model.load_state_dict(state["model"], strict=True) + model.eval() + self.model = model + self.layer = layer + self.use_cuda = use_cuda + if self.use_cuda: + self.model.cuda() + + def read_audio(self, fname, channel_id=None): + wav, sr = sf.read(fname) + if channel_id is not None: + assert wav.ndim == 2, \ + f"Expected stereo input when channel_id is given ({fname})" + assert channel_id in [1, 2], \ + "channel_id is expected to be in [1, 2]" + wav = wav[:, channel_id-1] + if wav.ndim == 2: + wav = wav.mean(-1) + assert wav.ndim == 1, wav.ndim + assert sr == self.task.cfg.sample_rate, sr + return wav + + def get_feats(self, file_path, channel_id=None): + x = self.read_audio(file_path, channel_id) + with torch.no_grad(): + source = torch.from_numpy(x).view(1, -1).float() + if self.use_cuda: + source = source.cuda() + res = self.model( + source=source, mask=False, features_only=True, layer=self.layer + ) + return res["layer_results"][self.layer][0].squeeze(1) diff --git a/fairseq/examples/textless_nlp/gslm/tools/README.md b/fairseq/examples/textless_nlp/gslm/tools/README.md new file mode 100644 index 0000000000000000000000000000000000000000..385834841cc1376bef5bc73b2613806ed5cde38c --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/tools/README.md @@ -0,0 +1,25 @@ +# GSLM Tools + +## Resynthesis +You can use the command line tool below to input an audio file and get the resynthesized audio. This tool implements the unsupervised method for resynthesis described in the paper. The way to invoke the command line tool is shown below. 
+``` +FAIRSEQ_ROOT= +TYPE= +ACOUSTIC_MODEL_PATH= +LAYER= +KM_MODEL_PATH= +TTS_MODEL_PATH= +# A text file containing the codes, one per line +CODE_DICT_PATH= +WAVEGLOW_PATH= + +PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/tools/resynthesize_speech.py \ + --feature_type $TYPE \ + --acoustic_model_path $ACOUSTIC_MODEL_PATH \ + --layer $LAYER \ + --kmeans_model_path $KM_MODEL_PATH \ + --tts_model_path $TTS_MODEL_PATH \ + --code_dict_path $CODE_DICT_PATH \ + --waveglow_path $WAVEGLOW_PATH \ + --max_decoder_steps 2000 +``` \ No newline at end of file diff --git a/fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py b/fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py new file mode 100644 index 0000000000000000000000000000000000000000..309877212e19218fe400a4f12a95258ced914ed7 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/tools/resynthesize_speech.py @@ -0,0 +1,132 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import gc +import logging +import os + +import joblib +import soundfile as sf +import torch +from examples.textless_nlp.gslm.speech2unit.pretrained.utils import get_feature_reader +from examples.textless_nlp.gslm.unit2speech.tts_data import TacotronInputDataset +from examples.textless_nlp.gslm.unit2speech.utils import ( + load_tacotron, + load_waveglow, + synthesize_audio, +) + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +def get_parser(): + parser = argparse.ArgumentParser(description="GSLM U2S tool") + parser.add_argument( + "--feature_type", + type=str, + choices=["logmel", "hubert", "w2v2", "cpc"], + default=None, + required=True, + help="Acoustic feature type", + ) + parser.add_argument( + "--acoustic_model_path", + type=str, + help="Pretrained acoustic model checkpoint", + ) + parser.add_argument("--layer", type=int, help="Layer of acoustic model") + parser.add_argument( + "--kmeans_model_path", + type=str, + required=True, + help="K-means model file path to use for inference", + ) + parser.add_argument( + "--tts_model_path", + type=str, + help="TTS model file path to use for inference", + ) + parser.add_argument( + "--code_dict_path", + type=str, + help="Code dict file path to use for inference", + ) + parser.add_argument( + "--waveglow_path", + type=str, + help="Waveglow (vocoder) model file path to use for inference", + ) + parser.add_argument("--max_decoder_steps", type=int, default=2000) + parser.add_argument("--denoiser_strength", type=float, default=0.1) + return parser + + +################################################ +def main(args, logger): + # Acoustic Model + logger.info(f"Loading acoustic model from {args.tts_model_path}...") + feature_reader_cls = get_feature_reader(args.feature_type) + reader = feature_reader_cls( + checkpoint_path=args.acoustic_model_path, 
layer=args.layer + ) + + # K-means Model + logger.info(f"Loading K-means model from {args.kmeans_model_path} ...") + kmeans_model = joblib.load(open(args.kmeans_model_path, "rb")) + kmeans_model.verbose = False + + # TTS Model + logger.info(f"Loading TTS model from {args.tts_model_path}...") + tacotron_model, sample_rate, hparams = load_tacotron( + tacotron_model_path=args.tts_model_path, + max_decoder_steps=args.max_decoder_steps, + ) + + # Waveglow Model + logger.info(f"Loading Waveglow model from {args.waveglow_path}...") + waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path) + + # Dataset + if not os.path.exists(hparams.code_dict): + hparams.code_dict = args.code_dict_path + tts_dataset = TacotronInputDataset(hparams) + + iters = 0 + while True: + in_file_path = input("Input: Enter the full file path of audio file...\n") + out_file_path = input("Output: Enter the full file path of audio file...\n") + feats = reader.get_feats(in_file_path).cpu().numpy() + iters += 1 + if iters == 1000: + gc.collect() + torch.cuda.empty_cache() + + quantized_units = kmeans_model.predict(feats) + quantized_units_str = " ".join(map(str, quantized_units)) + + tts_input = tts_dataset.get_tensor(quantized_units_str) + mel, aud, aud_dn, has_eos = synthesize_audio( + tacotron_model, + waveglow, + denoiser, + tts_input.unsqueeze(0), + strength=args.denoiser_strength, + ) + sf.write(f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate) + logger.info("Resynthesis done!\n") + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git a/fairseq/examples/textless_nlp/gslm/ulm/README.md b/fairseq/examples/textless_nlp/gslm/ulm/README.md new file mode 100644 index 0000000000000000000000000000000000000000..01459121cebefc61fdc2eae201462aa78d699111 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/ulm/README.md @@ -0,0 +1,72 @@ +# Unit Language Model (ULM) + +Here 
you can find links to the pre-trained ULMs and instructions on training new models using fairseq. At the end of the page, we also share how to run sampling for those models and provide pointers to the transcribed prompts we used. + +## Pre-trained models + +Using the links below, you can download pre-trained models for various unit types and vocabulary sizes: + +| | 50 | 100 | 200 +|-|-|-|- +| LogMel Filterbank | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km50/logmel50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km100/logmel100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/lm_km200/logmel200_lm.tgz) +| Modified CPC | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km50/cpc50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km100/cpc100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/lm_km200/cpc200_lm.tgz) +| HuBERT | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km50/hubert50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km100/hubert100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/lm_km200/hubert200_lm.tgz) +| Wav2Vec 2.0 | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km50/w2v2_50_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km100/w2v2_100_lm.tgz) | [download](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/lm_km200/w2v2_200_lm.tgz) + + +## Preprocessing data +Assuming that unit-transcribed train, valid, and test sets are located in `data/train.txt`, `data/valid.txt`, and `data/test.txt`, respectively, +we run the following command to get a preprocessed version of the datast in `data-bin`: + +```bash +fairseq-preprocess --only-source \ + --trainpref data/train.txt --validpref data/valid.txt --testpref data/test.txt \ + --destdir data-bin/ --workers 40 
+``` +As a result, the `data-bin` directory should appear. + +## Fitting a Unit Language Model (ULM) +As an ULM, we train a standard fairseq Transformer LM. Assuming 8 GPUs used for training, a good starting point for an ULM training would be: +```bash + fairseq-train data-bin/ \ + --task=language_modeling \ + --arch=transformer_lm_big \ + --share-decoder-input-output-embed \ + --dropout=0.1 \ + --attention-dropout=0.1 \ + --optimizer=adam \ + --adam-betas='(0.9, 0.98)' \ + --clip-norm=1.0 \ + --lr=0.0005 \ + --lr-scheduler=inverse_sqrt \ + --warmup-updates=4000 \ + --warmup-init-lr=1e-07 \ + --tokens-per-sample=3072 \ + --update-freq=16 \ + --max-tokens=4096 \ + --num-workers=4 \ + --skip-invalid-size-inputs-valid-test \ + --max-update=500000 \ + --log-interval=10 \ + --seed=100501 \ + --fp16 \ + --sample-break-mode=eos +``` +This command will train a Transformer-large model (12 layers). You can train other standard LM models provided by fairseq, e.g. specify `--arch=transformer_lm` to train a smaller (6-layer) Transformer model. When training with a different number of GPUs, it might be a good idea to adjust the `update-freq` parameter. To save the GPU memory at an expense of additional computation, it can be useful to enable activation checkpointing with `--checkpoint-activations`. + +## Sampling from an ULM +Once an ULM was trained, we can use it for generating new utterances. Suppose, that the prompts are given in a file named `prompts.txt`. Then we can sample continuations by running the following command: + +```bash + python sample.py data-bin/ \ + --path=checkpoints/checkpoint_best.pt --task=language_modeling --sampling --temperature=0.7 \ + --seed=1 --prompts=prompts.txt --output=samples.txt --max-len-a=0 --max-len-b=500 \ + --prefix-size=-1 --batch-size=16 --fp16 --samples-per-prompt=10 +``` +Here, `--prefix-size` controls the number of tokens that are used to prime the ULM. 
When set to a positive value, the sampling script will take first `prefix-size` tokens to prompt the ULM; with `0` it runs unconditional sampling and with `-1` the entire prompt is used. +`--samples-per-prompt` specifies how many utterances are generated with every prompt which can be useful when generating multiple prompt continuations. In this command, `--max-len-a` and `--max-len-b` control the number of generated tokens. + +When using a pretrained model from above, `data-bin` should point to the unpacked directory (with `dict.txt` file). + +Evaluation-time, to generate prompts, we used utterances from LibriSpeech dev-clean and test-clean that are longer than 6s. We took first 3s from an utterance as a prompt. Unit transcripts of those prompts can be downloaded here: [[dev]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/dev_prompts.tgz) [[test]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/eval_data/test_prompts.tgz) + diff --git a/fairseq/examples/textless_nlp/gslm/ulm/sample.py b/fairseq/examples/textless_nlp/gslm/ulm/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..77302a6894cacf07588cf34fb1e695dc519d7df5 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/ulm/sample.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 -u +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+""" +Sample from a trained LM; hacked fairseq-interactive +""" +from collections import namedtuple +import os +import ast +import numpy as np + +from fairseq import checkpoint_utils, options, tasks, utils + +import tqdm + +Batch = namedtuple('Batch', 'ids src_tokens src_lengths') +Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments') + + +def make_batches(lines, args, task, max_positions): + tokens = [ + task.source_dictionary.encode_line( + src_str, add_if_not_exist=False + ).long() + for src_str in lines + ] + lengths = [t.numel() for t in tokens] + itr = task.get_batch_iterator( + dataset=task.build_dataset_for_inference(tokens, lengths), + max_tokens=args.dataset.max_tokens, + max_sentences=args.dataset.batch_size, + max_positions=max_positions, + ignore_invalid_inputs=args.dataset.skip_invalid_size_inputs_valid_test + ).next_epoch_itr(shuffle=False) + for batch in itr: + yield Batch( + ids=batch['id'], + src_tokens=batch['net_input']['src_tokens'], src_lengths=batch['net_input']['src_lengths'], + ) + + +def main(args): + arg_prompts = args.prompts + arg_output = args.output + arg_debug = args.debug + arg_sample_size = args.samples_per_prompt + + try: + from fairseq.dataclass.utils import convert_namespace_to_omegaconf + args = convert_namespace_to_omegaconf(args) + except: + pass + + # if args.max_tokens is None and args.max_sentences is None: + if args.common.seed is not None: + np.random.seed(args.common.seed) + utils.set_torch_seed(args.common.seed) + + if args.generation.sampling: + args.generation.nbest = args.generation.beam = arg_sample_size + + task = tasks.setup_task(args.task) + + overrides = ast.literal_eval(args.common_eval.model_overrides) + + models, _model_args = checkpoint_utils.load_model_ensemble( + args.common_eval.path.split(os.pathsep), + arg_overrides=overrides, + task=task, + suffix=getattr(args, "checkpoint_suffix", ""), + ) + + # Set dictionaries + src_dict = task.source_dictionary + tgt_dict = 
task.target_dictionary + + # Optimize ensemble for generation + for model in models: + model.prepare_for_inference_(args) + model.cuda() + + # Load alignment dictionary for unknown word replacement + # (None if no unknown word replacement, empty if no path to align dictionary) + align_dict = utils.load_align_dict(args.generation.replace_unk) + + max_positions = utils.resolve_max_positions( + task.max_positions(), + *[model.max_positions() for model in models] + ) + + output_file = open(arg_output, 'w') + + with open(arg_prompts, 'r') as fin: + lines = fin.readlines() + + split = [x.split('|', 1) for x in lines] + seq_id = [x[0] for x in split] + prompts = [x[1] for x in split] + + if args.generation.prefix_size >= 0: + prompts = [' '.join(l.split()[:args.generation.prefix_size]) + for l in prompts] + + if arg_debug: + prompts = prompts[:10] + + generator = task.build_generator(models, args.generation) + + start_id = 0 + pbar = tqdm.tqdm(total=len(prompts)) + for batch in make_batches(prompts, args, task, max_positions): + src_tokens = batch.src_tokens + src_lengths = batch.src_lengths + src_tokens = src_tokens.cuda() + src_lengths = src_lengths.cuda() + + sample = { + 'net_input': { + 'src_tokens': src_tokens, + 'src_lengths': src_lengths, + }, + } + + results = [] + translations = task.inference_step(generator, models, sample) + for i, (id, hypos) in enumerate(zip(batch.ids.tolist(), translations)): + src_tokens_i = utils.strip_pad(src_tokens[i], tgt_dict.pad()) + results.append((i + start_id, src_tokens_i, hypos)) + + # sort output to match input order + for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]): + if src_dict is not None: + src_str = src_dict.string( + src_tokens, args.common_eval.post_process) + + # Process top predictions + for hypo_id, hypo in enumerate(hypos): + _hypo_tokens, hypo_str, _alignment = utils.post_process_prediction( + hypo_tokens=hypo['tokens'].int().cpu(), + src_str=src_str, + alignment=hypo['alignment'], + 
align_dict=align_dict, + tgt_dict=tgt_dict, + remove_bpe=args.common_eval.post_process, + ) + + detok_hypo_str = hypo_str + utterance = detok_hypo_str + print(f'{seq_id[id]}__{hypo_id}|{utterance}', file=output_file) + pbar.update(1) + start_id += len(results) + + # output_file.close() + + +def cli_main(): + parser = options.get_interactive_generation_parser() + parser.add_argument('--prompts', type=str, default=None, required=True) + parser.add_argument('--output', type=str, default=None, required=True) + parser.add_argument('--debug', action='store_true') + parser.add_argument('--samples-per-prompt', type=int, default=1) + + args = options.parse_args_and_arch(parser) + + np.random.seed(args.seed) + utils.set_torch_seed(args.seed) + + main(args) + + +if __name__ == '__main__': + cli_main() diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/README.md b/fairseq/examples/textless_nlp/gslm/unit2speech/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e61601392bd77dc747e43182b408bb941bc4728b --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/README.md @@ -0,0 +1,40 @@ +# Unit to Speech Model (unit2speech) + +Unit to speech model is modified Tacotron2 model that learns to synthesize speech from discrete speech units. All models are trained on quantized [LJSpeech](https://keithito.com/LJ-Speech-Dataset/). 
+ +Upstream Units | Download Links | model md5 +|-|-|- +Log Mel Filterbank + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km50/code_dict) | 932b3b8527c0125f5f964b57762eba49 +Log Mel Filterbank + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km100/code_dict) | cde0b0d278a39011d0acbd5df27abdf4 +Log Mel Filterbank + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/logmel/tts_km200/code_dict) | dba0f1d4de64bc7976718834010b23e7 +Modified CPC + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km50/code_dict) | a585e8dd8890ea56164f17635dd8e613 +Modified CPC + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km100/code_dict) | 5c0ee2869b4f483d17f37f1a41a548e0 +Modified CPC + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/cpc/tts_km200/code_dict) | 2f0c9951cf37020d9464514bff48bc5d +HuBERT Base + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km50/code_dict) | 85ffce8baec5aa90035ab696fe676fce +HuBERT Base + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km100/code_dict) | 
df4a9c6ffd1bb00c91405432c234aba3
+HuBERT Base + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/tts_km200/code_dict) | ac72f2c0c563589819bec116c7f8d274
+wav2vec 2.0 Large + KM50 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km50/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km50/code_dict) | e3503d0ad822b2c24b89f68b857fedff
+wav2vec 2.0 Large + KM100 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km100/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km100/code_dict) | eb3666e456ae4c96bf2a1eec825c13ed
+wav2vec 2.0 Large + KM200 | [model](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km200/tts_checkpoint_best.pt) - [code_dict](https://dl.fbaipublicfiles.com/textless_nlp/gslm/w2v2/tts_km200/code_dict) | 777d343e963c4d64f04d78eef032f4e8
+
+## Run inference using a unit2speech model
+* Install librosa, unidecode and inflect using `pip install librosa unidecode inflect`
+* Download [Waveglow checkpoint](https://dl.fbaipublicfiles.com/textless_nlp/gslm/waveglow_256channels_new.pt). This is the vocoder.
+
+Sample command to run inference using trained unit2speech models. Please note that the quantized audio to be synthesized should be using the same units as the unit2speech model was trained with.
+``` +FAIRSEQ_ROOT= +TTS_MODEL_PATH= +QUANTIZED_UNIT_PATH= +OUT_DIR= +WAVEGLOW_PATH= +CODE_DICT_PATH= + +PYTHONPATH=${FAIRSEQ_ROOT}:${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech python ${FAIRSEQ_ROOT}/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py \ + --tts_model_path $TTS_MODEL_PATH \ + --quantized_unit_path $QUANTIZED_UNIT_PATH \ + --out_audio_dir $OUT_DIR \ + --waveglow_path $WAVEGLOW_PATH \ + --code_dict_path $CODE_DICT_PATH \ + --max_decoder_steps 2000 +``` diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py b/fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py new file mode 100644 index 0000000000000000000000000000000000000000..2be848fceae65e3bd5747a2c98106b0215c6a039 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/convert_to_16k.py @@ -0,0 +1,56 @@ +import os +import shlex +import subprocess +import progressbar +from time import time +from pathlib import Path + +def find_all_files(path_dir, extension): + out = [] + for root, dirs, filenames in os.walk(path_dir): + for f in filenames: + if f.endswith(extension): + out.append(((str(Path(f).stem)), os.path.join(root, f))) + return out + +def convert16k(inputfile, outputfile16k): + command = ('sox -c 1 -b 16 {} -t wav {} rate 16k'.format(inputfile, outputfile16k)) + subprocess.call(shlex.split(command)) + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='Convert to wav 16k audio using sox.') + parser.add_argument('input_dir', type=str, + help='Path to the input dir.') + parser.add_argument('output_dir', type=str, + help='Path to the output dir.') + parser.add_argument('--extension', type=str, default='wav', + help='Audio file extension in the input. 
Default: mp3') + args = parser.parse_args() + + # Find all sequences + print(f"Finding all audio files with extension '{args.extension}' from {args.input_dir}...") + audio_files = find_all_files(args.input_dir, args.extension) + print(f"Done! Found {len(audio_files)} files.") + + # Convert to relative path + audio_files = [os.path.relpath(file[-1], start=args.input_dir) for file in audio_files] + + # Create all the directories needed + rel_dirs_set = set([os.path.dirname(file) for file in audio_files]) + for rel_dir in rel_dirs_set: + Path(os.path.join(args.output_dir, rel_dir)).mkdir(parents=True, exist_ok=True) + + # Converting wavs files + print("Converting the audio to wav files...") + bar = progressbar.ProgressBar(maxval=len(audio_files)) + bar.start() + start_time = time() + for index, file in enumerate(audio_files): + bar.update(index) + input_file = os.path.join(args.input_dir, file) + output_file = os.path.join(args.output_dir, os.path.splitext(file)[0]+".wav") + convert16k(input_file, output_file) + bar.finish() + print(f"...done {len(audio_files)} files in {time()-start_time} seconds.") \ No newline at end of file diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/glow.py b/fairseq/examples/textless_nlp/gslm/unit2speech/glow.py new file mode 100644 index 0000000000000000000000000000000000000000..41fd437feb17a5b40d4cbc43a0ab4e69867618bc --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/glow.py @@ -0,0 +1,312 @@ +# ***************************************************************************** +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of the NVIDIA CORPORATION nor the +# names of its contributors may be used to endorse or promote products +# derived from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY +# DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +# ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +# ***************************************************************************** +import copy +import torch +from torch.autograd import Variable +import torch.nn.functional as F + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a+input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class WaveGlowLoss(torch.nn.Module): + def __init__(self, sigma=1.0): + super(WaveGlowLoss, self).__init__() + self.sigma = sigma + + def forward(self, model_output): + z, log_s_list, log_det_W_list = model_output + for i, log_s in enumerate(log_s_list): + if i == 0: + log_s_total = torch.sum(log_s) + log_det_W_total = log_det_W_list[i] + else: + log_s_total = log_s_total + torch.sum(log_s) + log_det_W_total += log_det_W_list[i] + + loss = torch.sum(z*z)/(2*self.sigma*self.sigma) - log_s_total - log_det_W_total + return loss/(z.size(0)*z.size(1)*z.size(2)) + + +class Invertible1x1Conv(torch.nn.Module): + """ + The layer outputs both the convolution, and the log determinant + of its weight matrix. 
If reverse=True it does convolution with + inverse + """ + def __init__(self, c): + super(Invertible1x1Conv, self).__init__() + self.conv = torch.nn.Conv1d(c, c, kernel_size=1, stride=1, padding=0, + bias=False) + + # Sample a random orthonormal matrix to initialize weights + _qr = torch.linalg.qr if torch.__version__ >= "1.8" else torch.qr + W = _qr(torch.FloatTensor(c, c).normal_())[0] + + # Ensure determinant is 1.0 not -1.0 + if torch.det(W) < 0: + W[:,0] = -1*W[:,0] + W = W.view(c, c, 1) + self.conv.weight.data = W + + def forward(self, z, reverse=False): + # shape + batch_size, group_size, n_of_groups = z.size() + + W = self.conv.weight.squeeze() + + if reverse: + if not hasattr(self, 'W_inverse'): + # Reverse computation + W_inverse = W.float().inverse() + W_inverse = Variable(W_inverse[..., None]) + if z.type() == 'torch.cuda.HalfTensor': + W_inverse = W_inverse.half() + self.W_inverse = W_inverse + z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) + return z + else: + # Forward computation + log_det_W = batch_size * n_of_groups * torch.logdet(W) + z = self.conv(z) + return z, log_det_W + + +class WN(torch.nn.Module): + """ + This is the WaveNet like layer for the affine coupling. The primary difference + from WaveNet is the convolutions need not be causal. There is also no dilation + size reset. The dilation only doubles on each layer + """ + def __init__(self, n_in_channels, n_mel_channels, n_layers, n_channels, + kernel_size): + super(WN, self).__init__() + assert(kernel_size % 2 == 1) + assert(n_channels % 2 == 0) + self.n_layers = n_layers + self.n_channels = n_channels + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + + start = torch.nn.Conv1d(n_in_channels, n_channels, 1) + start = torch.nn.utils.weight_norm(start, name='weight') + self.start = start + + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. 
This helps with training stability + end = torch.nn.Conv1d(n_channels, 2*n_in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + + cond_layer = torch.nn.Conv1d(n_mel_channels, 2*n_channels*n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name='weight') + + for i in range(n_layers): + dilation = 2 ** i + padding = int((kernel_size*dilation - dilation)/2) + in_layer = torch.nn.Conv1d(n_channels, 2*n_channels, kernel_size, + dilation=dilation, padding=padding) + in_layer = torch.nn.utils.weight_norm(in_layer, name='weight') + self.in_layers.append(in_layer) + + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2*n_channels + else: + res_skip_channels = n_channels + res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name='weight') + self.res_skip_layers.append(res_skip_layer) + + def forward(self, forward_input): + audio, spect = forward_input + audio = self.start(audio) + output = torch.zeros_like(audio) + n_channels_tensor = torch.IntTensor([self.n_channels]) + + spect = self.cond_layer(spect) + + for i in range(self.n_layers): + spect_offset = i*2*self.n_channels + acts = fused_add_tanh_sigmoid_multiply( + self.in_layers[i](audio), + spect[:,spect_offset:spect_offset+2*self.n_channels,:], + n_channels_tensor) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + audio = audio + res_skip_acts[:,:self.n_channels,:] + output = output + res_skip_acts[:,self.n_channels:,:] + else: + output = output + res_skip_acts + + return self.end(output) + + +class WaveGlow(torch.nn.Module): + def __init__(self, n_mel_channels, n_flows, n_group, n_early_every, + n_early_size, WN_config): + super(WaveGlow, self).__init__() + + self.upsample = torch.nn.ConvTranspose1d(n_mel_channels, + n_mel_channels, + 1024, stride=256) + assert(n_group % 2 == 0) + self.n_flows = n_flows + self.n_group = n_group + 
self.n_early_every = n_early_every + self.n_early_size = n_early_size + self.WN = torch.nn.ModuleList() + self.convinv = torch.nn.ModuleList() + + n_half = int(n_group/2) + + # Set up layers with the right sizes based on how many dimensions + # have been output already + n_remaining_channels = n_group + for k in range(n_flows): + if k % self.n_early_every == 0 and k > 0: + n_half = n_half - int(self.n_early_size/2) + n_remaining_channels = n_remaining_channels - self.n_early_size + self.convinv.append(Invertible1x1Conv(n_remaining_channels)) + self.WN.append(WN(n_half, n_mel_channels*n_group, **WN_config)) + self.n_remaining_channels = n_remaining_channels # Useful during inference + + def forward(self, forward_input): + """ + forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames + forward_input[1] = audio: batch x time + """ + spect, audio = forward_input + + # Upsample spectrogram to size of audio + spect = self.upsample(spect) + assert(spect.size(2) >= audio.size(1)) + if spect.size(2) > audio.size(1): + spect = spect[:, :, :audio.size(1)] + + spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) + spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) + + audio = audio.unfold(1, self.n_group, self.n_group).permute(0, 2, 1) + output_audio = [] + log_s_list = [] + log_det_W_list = [] + + for k in range(self.n_flows): + if k % self.n_early_every == 0 and k > 0: + output_audio.append(audio[:,:self.n_early_size,:]) + audio = audio[:,self.n_early_size:,:] + + audio, log_det_W = self.convinv[k](audio) + log_det_W_list.append(log_det_W) + + n_half = int(audio.size(1)/2) + audio_0 = audio[:,:n_half,:] + audio_1 = audio[:,n_half:,:] + + output = self.WN[k]((audio_0, spect)) + log_s = output[:, n_half:, :] + b = output[:, :n_half, :] + audio_1 = torch.exp(log_s)*audio_1 + b + log_s_list.append(log_s) + + audio = torch.cat([audio_0, audio_1],1) + + output_audio.append(audio) + return torch.cat(output_audio,1), 
log_s_list, log_det_W_list + + def infer(self, spect, sigma=1.0): + spect = self.upsample(spect) + # trim conv artifacts. maybe pad spec to kernel multiple + time_cutoff = self.upsample.kernel_size[0] - self.upsample.stride[0] + spect = spect[:, :, :-time_cutoff] + + spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) + spect = spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) + + if spect.type() == 'torch.cuda.HalfTensor': + audio = torch.cuda.HalfTensor(spect.size(0), + self.n_remaining_channels, + spect.size(2)).normal_() + else: + audio = torch.cuda.FloatTensor(spect.size(0), + self.n_remaining_channels, + spect.size(2)).normal_() + + audio = torch.autograd.Variable(sigma*audio) + + for k in reversed(range(self.n_flows)): + n_half = int(audio.size(1)/2) + audio_0 = audio[:,:n_half,:] + audio_1 = audio[:,n_half:,:] + + output = self.WN[k]((audio_0, spect)) + + s = output[:, n_half:, :] + b = output[:, :n_half, :] + audio_1 = (audio_1 - b)/torch.exp(s) + audio = torch.cat([audio_0, audio_1],1) + + audio = self.convinv[k](audio, reverse=True) + + if k % self.n_early_every == 0 and k > 0: + if spect.type() == 'torch.cuda.HalfTensor': + z = torch.cuda.HalfTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() + else: + z = torch.cuda.FloatTensor(spect.size(0), self.n_early_size, spect.size(2)).normal_() + audio = torch.cat((sigma*z, audio),1) + + audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data + return audio + + @staticmethod + def remove_weightnorm(model): + waveglow = model + for WN in waveglow.WN: + WN.start = torch.nn.utils.remove_weight_norm(WN.start) + WN.in_layers = remove(WN.in_layers) + WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer) + WN.res_skip_layers = remove(WN.res_skip_layers) + return waveglow + + +def remove(conv_list): + new_conv_list = torch.nn.ModuleList() + for old_conv in conv_list: + old_conv = torch.nn.utils.remove_weight_norm(old_conv) + 
new_conv_list.append(old_conv) + return new_conv_list diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py b/fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py new file mode 100644 index 0000000000000000000000000000000000000000..2a287a4e97c66acbd36897b25f2ece5494005f03 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/multiproc.py @@ -0,0 +1,27 @@ +import os +import time +import torch +import sys +import subprocess + +argslist = list(sys.argv)[1:] +log_dir = argslist[-1] +num_gpus = torch.cuda.device_count() +argslist.append('--n_gpus={}'.format(num_gpus)) +workers = [] +job_id = time.strftime("%Y_%m_%d-%H%M%S") +argslist.append("--group_name=group_{}".format(job_id)) + +print("GPU log directory is {}".format(log_dir)) +os.makedirs(log_dir, exist_ok=True) +for i in range(num_gpus): + argslist.append('--rank={}'.format(i)) + stdout = None if i == 0 else open("{}/{}_GPU_{}.log".format(log_dir, job_id, i), + "w") + print(argslist) + p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) + workers.append(p) + argslist = argslist[:-1] + +for p in workers: + p.wait() diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py b/fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py new file mode 100644 index 0000000000000000000000000000000000000000..80730843bf6ae5bbb10e190df38941f39b14cb9a --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/synthesize_audio_from_units.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import logging +import os + +import soundfile as sf +from examples.textless_nlp.gslm.unit2speech.tts_data import ( + TacotronInputDataset, +) +from examples.textless_nlp.gslm.unit2speech.utils import ( + load_quantized_audio_from_file, + load_tacotron, + load_waveglow, + synthesize_audio, +) + + +def get_logger(): + log_format = "[%(asctime)s] [%(levelname)s]: %(message)s" + logging.basicConfig(format=log_format, level=logging.INFO) + logger = logging.getLogger(__name__) + return logger + + +def get_parser(): + parser = argparse.ArgumentParser( + description="Wav2Vec 2.0 speech generator." + ) + parser.add_argument( + "--quantized_unit_path", + type=str, + help="K-means model file path to use for inference", + ) + parser.add_argument( + "--tts_model_path", + type=str, + help="TTS model file path to use for inference", + ) + parser.add_argument( + "--waveglow_path", + type=str, + help="Path to the waveglow checkpoint (vocoder).", + ) + parser.add_argument( + "--code_dict_path", + type=str, + help="Code dict file path to use for inference", + ) + parser.add_argument("--max_decoder_steps", type=int, default=2000) + parser.add_argument("--denoiser_strength", type=float, default=0.1) + parser.add_argument( + "--out_audio_dir", + type=str, + help="Output directory to dump audio files", + ) + + return parser + + +def main(args, logger): + # Load quantized audio + logger.info(f"Loading quantized audio from {args.quantized_unit_path}...") + names_batch, quantized_units_batch = load_quantized_audio_from_file( + file_path=args.quantized_unit_path + ) + + logger.info(f"Loading TTS model from {args.tts_model_path}...") + tacotron_model, sample_rate, hparams = load_tacotron( + tacotron_model_path=args.tts_model_path, + max_decoder_steps=args.max_decoder_steps, + ) + + logger.info(f"Loading Waveglow model from {args.waveglow_path}...") + waveglow, denoiser = load_waveglow(waveglow_path=args.waveglow_path) + + if not os.path.exists(hparams.code_dict): + 
hparams.code_dict = args.code_dict_path + tts_dataset = TacotronInputDataset(hparams) + + for name, quantized_units in zip(names_batch, quantized_units_batch): + quantized_units_str = " ".join(map(str, quantized_units)) + tts_input = tts_dataset.get_tensor(quantized_units_str) + mel, aud, aud_dn, has_eos = synthesize_audio( + tacotron_model, + waveglow, + denoiser, + tts_input.unsqueeze(0), + strength=args.denoiser_strength, + ) + out_file_path = os.path.join(args.out_audio_dir, f"{name}.wav") + sf.write( + f"{out_file_path}", aud_dn[0].cpu().float().numpy(), sample_rate + ) + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + logger = get_logger() + logger.info(args) + main(args, logger) diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py new file mode 100644 index 0000000000000000000000000000000000000000..b5af7f723eb8047bc58db2f85234aea161fbc659 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/audio_processing.py @@ -0,0 +1,93 @@ +import torch +import numpy as np +from scipy.signal import get_window +import librosa.util as librosa_util + + +def window_sumsquare(window, n_frames, hop_length=200, win_length=800, + n_fft=800, dtype=np.float32, norm=None): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. 
+ + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. + + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm)**2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C diff 
--git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py new file mode 100644 index 0000000000000000000000000000000000000000..e2e35c1a8cc4c628c5d05802677142c9a2122d2b --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cleaners.py @@ -0,0 +1,90 @@ +""" from https://github.com/keithito/tacotron """ + +''' +Cleaners are transformations that run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +''' + +import re +from unidecode import unidecode +from .numbers import normalize_numbers + + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' 
% x[0], re.IGNORECASE), x[1]) for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), +]] + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + '''Pipeline for English text, including number and abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py new file mode 100644 index 0000000000000000000000000000000000000000..62bfef745c30a56f7b6605d9e3becfbc40edb50d --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/cmudict.py @@ -0,0 +1,65 @@ +""" from https://github.com/keithito/tacotron """ + +import re + + +valid_symbols = [ + 'AA', 'AA0', 'AA1', 'AA2', 'AE', 
'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', 'AH2', + 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', 'AY1', 'AY2', + 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', 'ER1', 'ER2', 'EY', + 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', 'IH1', 'IH2', 'IY', 'IY0', 'IY1', + 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', + 'OY1', 'OY2', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', + 'UW0', 'UW1', 'UW2', 'V', 'W', 'Y', 'Z', 'ZH' +] + +_valid_symbol_set = set(valid_symbols) + + +class CMUDict: + '''Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding='latin-1') as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = {word: pron for word, pron in entries.items() if len(pron) == 1} + self._entries = entries + + + def __len__(self): + return len(self._entries) + + + def lookup(self, word): + '''Returns list of ARPAbet pronunciations of the given word.''' + return self._entries.get(word.upper()) + + + +_alt_re = re.compile(r'\([0-9]+\)') + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): + parts = line.split(' ') + word = re.sub(_alt_re, '', parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(' ') + for part in parts: + if part not in _valid_symbol_set: + return None + return ' '.join(parts) diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py new file mode 100644 index 
0000000000000000000000000000000000000000..f10d557ff5a4fff03b94f81543bd58cf1a66bc8f --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/layers.py @@ -0,0 +1,103 @@ +import torch +from librosa.filters import mel as librosa_mel_fn +from .audio_processing import dynamic_range_compression +from .audio_processing import dynamic_range_decompression +from .stft import STFT +from .utils import get_mask_from_lengths + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear'): + super(ConvNorm, self).__init__() + if padding is None: + assert(kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + + +class GlobalAvgPool(torch.nn.Module): + def __init__(self): + super(GlobalAvgPool, self).__init__() + + def forward(self, x, lengths=None): + """Average pooling across time steps (dim=1) with optionally lengths. + Args: + x: torch.Tensor of shape (N, T, ...) 
+ lengths: None or torch.Tensor of shape (N,) + dim: dimension to pool + """ + if lengths is None: + return x.mean(dim=1, keepdim=False) + else: + mask = get_mask_from_lengths(lengths).type(x.type()).to(x.device) + mask_shape = list(mask.size()) + [1 for _ in range(x.ndimension()-2)] + mask = mask.reshape(*mask_shape) + numer = (x * mask).sum(dim=1, keepdim=False) + denom = mask.sum(dim=1, keepdim=False) + return numer / denom + + +class TacotronSTFT(torch.nn.Module): + def __init__(self, filter_length=1024, hop_length=256, win_length=1024, + n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0, + mel_fmax=8000.0): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer('mel_basis', mel_basis) + + def spectral_normalize(self, magnitudes): + output = dynamic_range_compression(magnitudes) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert(torch.min(y.data) >= -1) + assert(torch.max(y.data) <= 1) + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output) + return mel_output diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py new file mode 100644 index 
0000000000000000000000000000000000000000..ccf132b150a7cc1c125c1190b5fd8f43edaae685 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/model.py @@ -0,0 +1,669 @@ +from math import sqrt +import torch +import torch.distributions as distr +from torch.autograd import Variable +from torch import nn +from torch.nn import functional as F +from .layers import ConvNorm, LinearNorm, GlobalAvgPool +from .utils import to_gpu, get_mask_from_lengths + + +class LocationLayer(nn.Module): + def __init__(self, attention_n_filters, attention_kernel_size, + attention_dim): + super(LocationLayer, self).__init__() + padding = int((attention_kernel_size - 1) / 2) + self.location_conv = ConvNorm(2, attention_n_filters, + kernel_size=attention_kernel_size, + padding=padding, bias=False, stride=1, + dilation=1) + self.location_dense = LinearNorm(attention_n_filters, attention_dim, + bias=False, w_init_gain='tanh') + + def forward(self, attention_weights_cat): + processed_attention = self.location_conv(attention_weights_cat) + processed_attention = processed_attention.transpose(1, 2) + processed_attention = self.location_dense(processed_attention) + return processed_attention + + +class Attention(nn.Module): + def __init__(self, attention_rnn_dim, embedding_dim, attention_dim, + attention_location_n_filters, attention_location_kernel_size): + super(Attention, self).__init__() + self.query_layer = LinearNorm(attention_rnn_dim, attention_dim, + bias=False, w_init_gain='tanh') + self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False, + w_init_gain='tanh') + self.v = LinearNorm(attention_dim, 1, bias=False) + self.location_layer = LocationLayer(attention_location_n_filters, + attention_location_kernel_size, + attention_dim) + self.score_mask_value = -float("inf") + + def get_alignment_energies(self, query, processed_memory, + attention_weights_cat): + """ + PARAMS + ------ + query: decoder output (batch, n_mel_channels * n_frames_per_step) + 
processed_memory: processed encoder outputs (B, T_in, attention_dim) + attention_weights_cat: cumulative and prev. att weights (B, 2, max_time) + + RETURNS + ------- + alignment (batch, max_time) + """ + + processed_query = self.query_layer(query.unsqueeze(1)) + processed_attention_weights = self.location_layer(attention_weights_cat) + energies = self.v(torch.tanh( + processed_query + processed_attention_weights + processed_memory)) + + energies = energies.squeeze(-1) + return energies + + def forward(self, attention_hidden_state, memory, processed_memory, + attention_weights_cat, mask): + """ + PARAMS + ------ + attention_hidden_state: attention rnn last output + memory: encoder outputs + processed_memory: processed encoder outputs + attention_weights_cat: previous and cummulative attention weights + mask: binary mask for padded data + """ + alignment = self.get_alignment_energies( + attention_hidden_state, processed_memory, attention_weights_cat) + + if mask is not None: + alignment.data.masked_fill_(mask, self.score_mask_value) + + attention_weights = F.softmax(alignment, dim=1) + attention_context = torch.bmm(attention_weights.unsqueeze(1), memory) + attention_context = attention_context.squeeze(1) + + return attention_context, attention_weights + + +class Prenet(nn.Module): + def __init__(self, in_dim, sizes): + super(Prenet, self).__init__() + in_sizes = [in_dim] + sizes[:-1] + self.layers = nn.ModuleList( + [LinearNorm(in_size, out_size, bias=False) + for (in_size, out_size) in zip(in_sizes, sizes)]) + + def forward(self, x): + for linear in self.layers: + x = F.dropout(F.relu(linear(x)), p=0.5, training=True) + return x + + +class Postnet(nn.Module): + """Postnet + - Five 1-d convolution with 512 channels and kernel size 5 + """ + + def __init__(self, hparams): + super(Postnet, self).__init__() + self.convolutions = nn.ModuleList() + + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim, + 
kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hparams.postnet_embedding_dim)) + ) + + for i in range(1, hparams.postnet_n_convolutions - 1): + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.postnet_embedding_dim, + hparams.postnet_embedding_dim, + kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hparams.postnet_embedding_dim)) + ) + + self.convolutions.append( + nn.Sequential( + ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels, + kernel_size=hparams.postnet_kernel_size, stride=1, + padding=int((hparams.postnet_kernel_size - 1) / 2), + dilation=1, w_init_gain='linear'), + nn.BatchNorm1d(hparams.n_mel_channels)) + ) + + def forward(self, x): + for i in range(len(self.convolutions) - 1): + x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training) + x = F.dropout(self.convolutions[-1](x), 0.5, self.training) + + return x + + +class Encoder(nn.Module): + """Encoder module: + - Three 1-d convolution banks + - Bidirectional LSTM + """ + def __init__(self, hparams): + super(Encoder, self).__init__() + + convolutions = [] + for _ in range(hparams.encoder_n_convolutions): + conv_layer = nn.Sequential( + ConvNorm(hparams.encoder_embedding_dim, + hparams.encoder_embedding_dim, + kernel_size=hparams.encoder_kernel_size, stride=1, + padding=int((hparams.encoder_kernel_size - 1) / 2), + dilation=1, w_init_gain='relu'), + nn.BatchNorm1d(hparams.encoder_embedding_dim)) + convolutions.append(conv_layer) + self.convolutions = nn.ModuleList(convolutions) + + self.lstm = nn.LSTM(hparams.encoder_embedding_dim, + int(hparams.encoder_embedding_dim / 2), 1, + batch_first=True, bidirectional=True) + + def forward(self, x, input_lengths): + for conv in self.convolutions: + x = F.dropout(F.relu(conv(x)), 0.5, self.training) + + x = 
x.transpose(1, 2) + + # pytorch tensor are not reversible, hence the conversion + input_lengths = input_lengths.cpu().numpy() + x = nn.utils.rnn.pack_padded_sequence( + x, input_lengths, batch_first=True) + + self.lstm.flatten_parameters() + outputs, _ = self.lstm(x) + + outputs, _ = nn.utils.rnn.pad_packed_sequence( + outputs, batch_first=True) + + return outputs + + def inference(self, x): + for conv in self.convolutions: + x = F.dropout(F.relu(conv(x)), 0.5, self.training) + + x = x.transpose(1, 2) + + self.lstm.flatten_parameters() + outputs, _ = self.lstm(x) + + return outputs + + +class AudioEncoder(nn.Module): + def __init__(self, hparams): + super(AudioEncoder, self).__init__() + + assert hparams.lat_dim > 0 + + convolutions = [] + inp_dim = hparams.n_mel_channels + for _ in range(hparams.lat_n_convolutions): + conv_layer = nn.Sequential( + ConvNorm(inp_dim, hparams.lat_n_filters, + kernel_size=hparams.lat_kernel_size, stride=1, + padding=int((hparams.lat_kernel_size - 1) / 2), + dilation=1, w_init_gain='tanh'), + nn.BatchNorm1d(hparams.lat_n_filters)) + inp_dim = hparams.lat_n_filters + convolutions.append(conv_layer) + self.convolutions = nn.ModuleList(convolutions) + + self.lstm = nn.LSTM(hparams.lat_n_filters, + int(hparams.lat_n_filters / 2), + hparams.lat_n_blstms, batch_first=True, + bidirectional=True) + self.pool = GlobalAvgPool() + + self.mu_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim) + self.logvar_proj = LinearNorm(hparams.lat_n_filters, hparams.lat_dim) + self.lat_dim = hparams.lat_dim + + def forward(self, x, lengths): + """ + Args: + x (torch.Tensor): (B, F, T) + """ + + for conv in self.convolutions: + x = F.dropout(F.tanh(conv(x)), 0.5, self.training) + + x = x.transpose(1, 2) # (B, T, D) + + # x may not be sorted by length. 
Sort->process->unsort + max_len = x.size(1) + assert max_len == torch.max(lengths).item() + + lengths, perm_idx = lengths.sort(0, descending=True) + x = x[perm_idx] + x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True) + + self.lstm.flatten_parameters() + outputs, _ = self.lstm(x) + outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True) + + _, unperm_idx = perm_idx.sort(0) + outputs = outputs[unperm_idx] # (B, T, D) + lengths = lengths[unperm_idx] # (B, T, D) + + outputs = self.pool(outputs, lengths) # (B, D) + + mu = self.mu_proj(outputs) + logvar = self.logvar_proj(outputs) + z = distr.Normal(mu, logvar).rsample() + return z, mu, logvar + + +class Decoder(nn.Module): + def __init__(self, hparams): + super(Decoder, self).__init__() + self.n_mel_channels = hparams.n_mel_channels + self.n_frames_per_step = hparams.n_frames_per_step + self.encoder_embedding_dim = hparams.encoder_embedding_dim + self.obs_dim = hparams.obs_dim + self.lat_dim = hparams.lat_dim + self.attention_rnn_dim = hparams.attention_rnn_dim + self.decoder_rnn_dim = hparams.decoder_rnn_dim + self.prenet_dim = hparams.prenet_dim + self.max_decoder_steps = hparams.max_decoder_steps + self.gate_threshold = hparams.gate_threshold + self.p_attention_dropout = hparams.p_attention_dropout + self.p_decoder_dropout = hparams.p_decoder_dropout + + self.prenet = Prenet( + hparams.n_mel_channels * hparams.n_frames_per_step, + [hparams.prenet_dim, hparams.prenet_dim]) + + self.attention_rnn = nn.LSTMCell( + hparams.prenet_dim + hparams.encoder_embedding_dim, + hparams.attention_rnn_dim) + + self.attention_layer = Attention( + hparams.attention_rnn_dim, hparams.encoder_embedding_dim, + hparams.attention_dim, hparams.attention_location_n_filters, + hparams.attention_location_kernel_size) + + encoder_tot_dim = (hparams.encoder_embedding_dim + \ + hparams.lat_dim + hparams.obs_dim) + self.decoder_rnn = nn.LSTMCell( + hparams.attention_rnn_dim + encoder_tot_dim, + 
hparams.decoder_rnn_dim, 1) + + self.linear_projection = LinearNorm( + hparams.decoder_rnn_dim + encoder_tot_dim, + hparams.n_mel_channels * hparams.n_frames_per_step) + + self.gate_layer = LinearNorm( + hparams.decoder_rnn_dim + encoder_tot_dim, 1, + bias=True, w_init_gain='sigmoid') + + def get_go_frame(self, memory): + """ Gets all zeros frames to use as first decoder input + PARAMS + ------ + memory: decoder outputs + + RETURNS + ------- + decoder_input: all zeros frames + """ + B = memory.size(0) + decoder_input = Variable(memory.data.new( + B, self.n_mel_channels * self.n_frames_per_step).zero_()) + return decoder_input + + def initialize_decoder_states(self, memory, obs_and_lat, mask): + """ Initializes attention rnn states, decoder rnn states, attention + weights, attention cumulative weights, attention context, stores memory + and stores processed memory + PARAMS + ------ + memory: Encoder outputs + obs_and_lat: Observed and latent attribute embeddings + mask: Mask for padded data if training, expects None for inference + """ + B = memory.size(0) + MAX_TIME = memory.size(1) + + self.attention_hidden = Variable(memory.data.new( + B, self.attention_rnn_dim).zero_()) + self.attention_cell = Variable(memory.data.new( + B, self.attention_rnn_dim).zero_()) + + self.decoder_hidden = Variable(memory.data.new( + B, self.decoder_rnn_dim).zero_()) + self.decoder_cell = Variable(memory.data.new( + B, self.decoder_rnn_dim).zero_()) + + self.attention_weights = Variable(memory.data.new( + B, MAX_TIME).zero_()) + self.attention_weights_cum = Variable(memory.data.new( + B, MAX_TIME).zero_()) + self.attention_context = Variable(memory.data.new( + B, self.encoder_embedding_dim).zero_()) + + self.memory = memory + self.processed_memory = self.attention_layer.memory_layer(memory) + self.obs_and_lat = obs_and_lat + self.mask = mask + + def parse_decoder_inputs(self, decoder_inputs): + """ Prepares decoder inputs, i.e. 
mel outputs + PARAMS + ------ + decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs + + RETURNS + ------- + inputs: processed decoder inputs + + """ + # (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels) + decoder_inputs = decoder_inputs.transpose(1, 2) + decoder_inputs = decoder_inputs.view( + decoder_inputs.size(0), + int(decoder_inputs.size(1)/self.n_frames_per_step), -1) + # (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels) + decoder_inputs = decoder_inputs.transpose(0, 1) + return decoder_inputs + + def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments): + """ Prepares decoder outputs for output + PARAMS + ------ + mel_outputs: + gate_outputs: gate output energies + alignments: + + RETURNS + ------- + mel_outputs: + gate_outpust: gate output energies + alignments: + """ + # (T_out, B) -> (B, T_out) + alignments = torch.stack(alignments).transpose(0, 1) + # (T_out, B) -> (B, T_out) + gate_outputs = torch.stack(gate_outputs).transpose(0, 1) + gate_outputs = gate_outputs.contiguous() + # (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels) + mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous() + # decouple frames per step + mel_outputs = mel_outputs.view( + mel_outputs.size(0), -1, self.n_mel_channels) + # (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out) + mel_outputs = mel_outputs.transpose(1, 2) + + return mel_outputs, gate_outputs, alignments + + def decode(self, decoder_input): + """ Decoder step using stored states, attention and memory + PARAMS + ------ + decoder_input: previous mel output + + RETURNS + ------- + mel_output: + gate_output: gate output energies + attention_weights: + """ + cell_input = torch.cat((decoder_input, self.attention_context), -1) + self.attention_hidden, self.attention_cell = self.attention_rnn( + cell_input, (self.attention_hidden, self.attention_cell)) + self.attention_hidden = F.dropout( + self.attention_hidden, self.p_attention_dropout, self.training) + 
+ attention_weights_cat = torch.cat( + (self.attention_weights.unsqueeze(1), + self.attention_weights_cum.unsqueeze(1)), dim=1) + self.attention_context, self.attention_weights = self.attention_layer( + self.attention_hidden, self.memory, self.processed_memory, + attention_weights_cat, self.mask) + + self.attention_weights_cum += self.attention_weights + decoder_input = torch.cat( + (self.attention_hidden, self.attention_context), -1) + if self.obs_and_lat is not None: + decoder_input = torch.cat((decoder_input, self.obs_and_lat), -1) + self.decoder_hidden, self.decoder_cell = self.decoder_rnn( + decoder_input, (self.decoder_hidden, self.decoder_cell)) + self.decoder_hidden = F.dropout( + self.decoder_hidden, self.p_decoder_dropout, self.training) + + decoder_hidden_attention_context = torch.cat( + (self.decoder_hidden, self.attention_context), dim=1) + if self.obs_and_lat is not None: + decoder_hidden_attention_context = torch.cat( + (decoder_hidden_attention_context, self.obs_and_lat), dim=1) + decoder_output = self.linear_projection( + decoder_hidden_attention_context) + + gate_prediction = self.gate_layer(decoder_hidden_attention_context) + return decoder_output, gate_prediction, self.attention_weights + + def forward(self, memory, obs_and_lat, decoder_inputs, memory_lengths): + """ Decoder forward pass for training + PARAMS + ------ + memory: Encoder outputs + obs_and_lat: Observed and latent attribute embeddings + decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs + memory_lengths: Encoder output lengths for attention masking. 
+ + RETURNS + ------- + mel_outputs: mel outputs from the decoder + gate_outputs: gate outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + + decoder_input = self.get_go_frame(memory).unsqueeze(0) + decoder_inputs = self.parse_decoder_inputs(decoder_inputs) + decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0) + decoder_inputs = self.prenet(decoder_inputs) + + self.initialize_decoder_states( + memory, obs_and_lat, mask=~get_mask_from_lengths(memory_lengths)) + + mel_outputs, gate_outputs, alignments = [], [], [] + while len(mel_outputs) < decoder_inputs.size(0) - 1: + decoder_input = decoder_inputs[len(mel_outputs)] + mel_output, gate_output, attention_weights = self.decode( + decoder_input) + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output.squeeze()] + alignments += [attention_weights] + + mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( + mel_outputs, gate_outputs, alignments) + + return mel_outputs, gate_outputs, alignments + + def inference(self, memory, obs_and_lat, ret_has_eos=False): + """ Decoder inference + PARAMS + ------ + memory: Encoder outputs + obs_and_lat: Observed and latent attribute embeddings + + RETURNS + ------- + mel_outputs: mel outputs from the decoder + gate_outputs: gate outputs from the decoder + alignments: sequence of attention weights from the decoder + """ + decoder_input = self.get_go_frame(memory) + + self.initialize_decoder_states(memory, obs_and_lat, mask=None) + + mel_outputs, gate_outputs, alignments = [], [], [] + has_eos = False + while True: + decoder_input = self.prenet(decoder_input) + mel_output, gate_output, alignment = self.decode(decoder_input) + + mel_outputs += [mel_output.squeeze(1)] + gate_outputs += [gate_output] + alignments += [alignment] + + if torch.sigmoid(gate_output.data) > self.gate_threshold: + has_eos = True + break + elif len(mel_outputs) == self.max_decoder_steps: + # print("Warning! 
Reached max decoder steps") + break + + decoder_input = mel_output + + mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs( + mel_outputs, gate_outputs, alignments) + + if ret_has_eos: + return mel_outputs, gate_outputs, alignments, has_eos + else: + return mel_outputs, gate_outputs, alignments + + +class Tacotron2(nn.Module): + def __init__(self, hparams): + super(Tacotron2, self).__init__() + self.mask_padding = hparams.mask_padding + self.fp16_run = hparams.fp16_run + self.n_mel_channels = hparams.n_mel_channels + self.n_frames_per_step = hparams.n_frames_per_step + + # initialize text encoder embedding + self.embedding = nn.Embedding( + hparams.n_symbols, hparams.symbols_embedding_dim) + std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim)) + val = sqrt(3.0) * std # uniform bounds for std + self.embedding.weight.data.uniform_(-val, val) + + # initialize observed attribute embedding + self.obs_embedding = None + if hparams.obs_dim > 0: + self.obs_embedding = nn.Embedding( + hparams.obs_n_class, hparams.obs_dim) + std = sqrt(2.0 / (hparams.obs_n_class + hparams.obs_dim)) + val = sqrt(3.0) * std # uniform bounds for std + self.obs_embedding.weight.data.uniform_(-val, val) + + self.encoder = Encoder(hparams) + self.decoder = Decoder(hparams) + self.postnet = Postnet(hparams) + + self.lat_encoder = None + if hparams.lat_dim > 0: + self.lat_encoder = AudioEncoder(hparams) + + def parse_batch(self, batch): + (text_padded, input_lengths, obs_labels, + mel_padded, gate_padded, output_lengths) = batch + text_padded = to_gpu(text_padded).long() + input_lengths = to_gpu(input_lengths).long() + obs_labels = to_gpu(obs_labels).long() + max_len = torch.max(input_lengths.data).item() + mel_padded = to_gpu(mel_padded).float() + gate_padded = to_gpu(gate_padded).float() + output_lengths = to_gpu(output_lengths).long() + + return ( + (text_padded, input_lengths, obs_labels, + mel_padded, max_len, output_lengths), + (mel_padded, gate_padded)) + + def 
parse_output(self, outputs, output_lengths=None): + if self.mask_padding and output_lengths is not None: + mask = ~get_mask_from_lengths(output_lengths) + mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1)) + mask = mask.permute(1, 0, 2) + + outputs[0].data.masked_fill_(mask, 0.0) + outputs[1].data.masked_fill_(mask, 0.0) + outputs[2].data.masked_fill_(mask[:, 0, :], 1e3) # gate energies + + return outputs + + def forward(self, inputs): + (text_inputs, text_lengths, obs_labels, + mels, max_len, output_lengths) = inputs + text_lengths, output_lengths = text_lengths.data, output_lengths.data + + embedded_inputs = self.embedding(text_inputs).transpose(1, 2) + + encoder_outputs = self.encoder(embedded_inputs, text_lengths) + + obs = None + if self.obs_embedding is not None: + obs = self.obs_embedding(obs_labels) + + lat, lat_mu, lat_logvar = None, None, None + if self.lat_encoder is not None: + (lat, lat_mu, lat_logvar) = self.lat_encoder(mels, output_lengths) + + obs_and_lat = [x for x in [obs, lat] if x is not None] + if bool(obs_and_lat): + obs_and_lat = torch.cat(obs_and_lat, dim=-1) + else: + obs_and_lat = None + + mel_outputs, gate_outputs, alignments = self.decoder( + encoder_outputs, obs_and_lat, mels, memory_lengths=text_lengths) + + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + + return self.parse_output( + [mel_outputs, mel_outputs_postnet, gate_outputs, alignments, + lat_mu, lat_logvar], + output_lengths) + + def inference(self, inputs, obs_labels=None, lat=None, ret_has_eos=False): + embedded_inputs = self.embedding(inputs).transpose(1, 2) + encoder_outputs = self.encoder.inference(embedded_inputs) + + if obs_labels is None: + obs_labels = torch.LongTensor(len(inputs)) + obs_labels = obs_labels.to(inputs.device).zero_() + + obs = None + if self.obs_embedding is not None: + obs = self.obs_embedding(obs_labels) + + if self.lat_encoder is not None: + if lat is None: + lat = 
torch.FloatTensor(len(inputs), self.lat_encoder.lat_dim) + lat = lat.to(inputs.device).zero_().type(encoder_outputs.type()) + + obs_and_lat = [x for x in [obs, lat] if x is not None] + if bool(obs_and_lat): + obs_and_lat = torch.cat(obs_and_lat, dim=-1) + else: + obs_and_lat = None + + mel_outputs, gate_outputs, alignments, has_eos = self.decoder.inference( + encoder_outputs, obs_and_lat, ret_has_eos=True) + + mel_outputs_postnet = self.postnet(mel_outputs) + mel_outputs_postnet = mel_outputs + mel_outputs_postnet + + outputs = self.parse_output( + [mel_outputs, mel_outputs_postnet, gate_outputs, alignments]) + + if ret_has_eos: + return outputs + [has_eos] + else: + return outputs diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py new file mode 100644 index 0000000000000000000000000000000000000000..0d5f7fa818a45ecf132627d240afac653e148070 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/numbers.py @@ -0,0 +1,71 @@ +""" from https://github.com/keithito/tacotron """ + +import inflect +import re + + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return 
'%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py new file mode 100644 index 0000000000000000000000000000000000000000..63fcd431e2d7746b696aaa0d4172bc04ffb88efa --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/stft.py @@ -0,0 +1,141 @@ +""" +BSD 3-Clause License + +Copyright (c) 2017, Prem Seetharaman +All rights reserved. + +* Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. 
+ +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from this + software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+""" + +import torch +import numpy as np +import torch.nn.functional as F +from torch.autograd import Variable +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from .audio_processing import window_sumsquare + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + def __init__(self, filter_length=800, hop_length=200, win_length=800, + window='hann'): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]), + np.imag(fourier_basis[:cutoff, :])]) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :]) + + if window is not None: + assert(filter_length >= win_length) + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer('forward_basis', forward_basis.float()) + self.register_buffer('inverse_basis', inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode='reflect') + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + 
input_data, + Variable(self.forward_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + cutoff = int((self.filter_length / 2) + 1) + real_part = forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable( + torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, magnitude.size(-1), hop_length=self.hop_length, + win_length=self.win_length, n_fft=self.filter_length, + dtype=np.float32) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0]) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False) + window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length/2):] + inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py new file mode 100644 index 
0000000000000000000000000000000000000000..5f0d70fdad92ba4f554d971710b60f2f9e8d9298 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/symbols.py @@ -0,0 +1,18 @@ +""" from https://github.com/keithito/tacotron """ + +''' +Defines the set of symbols used in text input to the model. + +The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' +from . import cmudict + +_pad = '_' +_punctuation = '!\'(),.:;? ' +_special = '-' +_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz' + +# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): +_arpabet = ['@' + s for s in cmudict.valid_symbols] + +# Export all symbols: +symbols = [_pad] + list(_special) + list(_punctuation) + list(_letters) + _arpabet diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py new file mode 100644 index 0000000000000000000000000000000000000000..49e2ca498bf67ad226af5de796b9f44afa65198d --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tacotron2/text.py @@ -0,0 +1,107 @@ +""" from https://github.com/keithito/tacotron """ +import numpy as np +import re +from . import cleaners +from .symbols import symbols + + +# Mappings from symbol to numeric ID and vice versa: +_symbol_to_id = {s: i for i, s in enumerate(symbols)} +_id_to_symbol = {i: s for i, s in enumerate(symbols)} + +# Regular expression matching text enclosed in curly braces: +_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') + +# Special symbols +SOS_TOK = '' +EOS_TOK = '' + +def text_to_sequence(text, cleaner_names): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. 
def sample_code_chunk(code, size):
    """Sample a random contiguous chunk of `size` codes; return (chunk, start, end)."""
    assert 0 < size <= len(code)
    start = np.random.randint(len(code) - size + 1)
    end = start + size
    return code[start:end], start, end


def code_to_sequence(code, code_dict, collapse_code):
    """Map code tokens to integer IDs using `code_dict`.

    When `collapse_code` is True, consecutive duplicate codes are collapsed
    to one ID. Out-of-vocabulary codes are dropped; in the non-collapsing
    branch a warning is printed when more than 5% of codes are OOV.
    """
    if collapse_code:
        prev_c = None
        sequence = []
        for c in code:
            if c in code_dict and c != prev_c:
                sequence.append(code_dict[c])
                prev_c = c
    else:
        sequence = [code_dict[c] for c in code if c in code_dict]
        if len(sequence) < 0.95 * len(code):
            # BUG FIX: this is a plain print, not %-formatting, so the
            # original '%%' was printed literally as "5%%".
            print('WARNING : over 5% codes are OOV')

    return sequence


def sequence_to_text(sequence):
    """Convert a sequence of IDs back to a string (ARPAbet re-braced)."""
    result = ''
    for symbol_id in sequence:
        if symbol_id in _id_to_symbol:
            s = _id_to_symbol[symbol_id]
            # Enclose ARPAbet back in curly braces:
            if len(s) > 1 and s[0] == '@':
                s = '{%s}' % s[1:]
            result += s
    return result.replace('}{', ' ')


def sequence_to_code(sequence, code_dict):
    """Analogous to sequence_to_text, but for unit codes."""
    id_to_code = {i: c for c, i in code_dict.items()}
    return ' '.join([id_to_code[i] for i in sequence])


def _clean_text(text, cleaner_names):
    """Run `text` through each named cleaner from the cleaners module, in order."""
    for name in cleaner_names:
        cleaner = getattr(cleaners, name)
        if not cleaner:
            raise Exception('Unknown cleaner: %s' % name)
        text = cleaner(text)
    return text
def get_mask_from_lengths(lengths):
    """Return a boolean mask of shape (B, max_len): True where pos < length.

    BUG FIX: the original allocated the index tensor with
    ``torch.cuda.LongTensor``, which crashes on CPU-only machines. The
    indices are now allocated on the same device as `lengths`, which is
    behavior-identical on CUDA and correct on CPU.
    """
    max_len = torch.max(lengths).item()
    ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long)
    mask = (ids < lengths.unsqueeze(1))
    return mask


def load_wav_to_torch(full_path, sr=None):
    """Load audio with librosa, rescaled to the int16 range.

    Returns (FloatTensor of samples scaled to [-32768, 32768], sample_rate);
    the scaling matches values loaded by scipy.io.wavfile.read.
    """
    data, sr = librosa.load(full_path, sr=sr)
    data = np.clip(data, -1, 1)  # potentially out of [-1, 1] due to resampling
    data = data * 32768.0  # match values loaded by scipy
    return torch.FloatTensor(data.astype(np.float32)), sr
def load_filepaths_and_text(filename):
    """Read a manifest file containing one JSON object per line."""
    with open(filename, encoding='utf-8') as f:
        data = [json.loads(line.rstrip()) for line in f]
    return data


def to_gpu(x):
    """Return `x` as a contiguous tensor, moved to GPU when one is available.

    NOTE: the deprecated ``torch.autograd.Variable`` wrapper (a no-op since
    PyTorch 0.4) has been removed; callers see the same tensor either way.
    """
    x = x.contiguous()
    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    return x


def load_code_dict(path, add_sos=False, add_eos=False):
    """Load a unit-code file (one code per line) into {code: id}.

    Id 0 is reserved for the pad symbol '_'. Optionally appends SOS/EOS
    tokens at the end of the id space. Returns {} when `path` is falsy.
    """
    if not path:
        return {}

    with open(path, 'r') as f:
        codes = ['_'] + [line.rstrip() for line in f]  # '_' for pad
    code_dict = {c: i for i, c in enumerate(codes)}

    if add_sos:
        code_dict[SOS_TOK] = len(code_dict)
    if add_eos:
        code_dict[EOS_TOK] = len(code_dict)
    # Sanity check: ids must be exactly 0..N-1 (fails if a code repeats).
    assert set(code_dict.values()) == set(range(len(code_dict)))

    return code_dict


def load_obs_label_dict(path):
    """Load observation labels (one per line) into {label: index}; {} if no path."""
    if not path:
        return {}
    with open(path, 'r') as f:
        obs_labels = [line.rstrip() for line in f]
    return {c: i for i, c in enumerate(obs_labels)}
# Used to measure the wall-clock time taken for multiple named events.
class Timer:
    def __init__(self, keys):
        self.keys = keys
        self.n = {}
        self.running_time = {}
        self.total_time = {}
        self.reset()

    def start(self, key):
        """Start (or restart) timing `key`."""
        self.running_time[key] = time.time()
        return self

    def stop(self, key):
        """Stop timing `key` and accumulate the elapsed interval.

        BUG FIX: the elapsed time is now ACCUMULATED with `+=`. The original
        overwrote ``total_time[key]`` on every stop while still incrementing
        ``n[key]``, so ``value()`` reported (last interval / count) instead
        of the mean interval — inconsistent with CudaTimer, which accumulates.
        """
        self.total_time[key] += time.time() - self.running_time[key]
        self.n[key] += 1
        self.running_time[key] = None
        return self

    def reset(self):
        """Zero all counters and clear any in-flight timings."""
        for k in self.keys:
            self.total_time[k] = 0
            self.running_time[k] = None
            self.n[k] = 0
        return self

    def value(self):
        """Mean elapsed seconds per start/stop cycle, keyed by `self.keys`.

        Raises ValueError if any key was never stopped.
        """
        vals = {}
        for k in self.keys:
            if self.n[k] == 0:
                raise ValueError("Trying to divide by zero in TimeMeter")
            else:
                vals[k] = self.total_time[k] / self.n[k]
        return vals
device=waveglow.upsample.weight.device) + elif mode == 'normal': + mel_input = torch.randn( + (1, 80, 88), + dtype=waveglow.upsample.weight.dtype, + device=waveglow.upsample.weight.device) + else: + raise Exception("Mode {} if not supported".format(mode)) + + with torch.no_grad(): + bias_audio = waveglow.infer(mel_input, sigma=0.0).float() + bias_spec, _ = self.stft.transform(bias_audio) + + self.register_buffer('bias_spec', bias_spec[:, :, 0][:, :, None]) + + def forward(self, audio, strength=0.1): + audio_spec, audio_angles = self.stft.transform(audio.cuda().float()) + audio_spec_denoised = audio_spec - self.bias_spec * strength + audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0) + audio_denoised = self.stft.inverse(audio_spec_denoised, audio_angles) + return audio_denoised diff --git a/fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py b/fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py new file mode 100644 index 0000000000000000000000000000000000000000..d2b04c0feea0f7656f426e63090a6d4fb5fb8ed8 --- /dev/null +++ b/fairseq/examples/textless_nlp/gslm/unit2speech/tts_data.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
class TacotronInputDataset:
    """Turns raw text or unit-code strings into Tacotron2 input id tensors."""

    def __init__(self, hparams, append_str=""):
        # "text" mode runs the English text cleaners; otherwise inputs are
        # space-separated discrete unit codes.
        self.is_text = getattr(hparams, "text_or_code", "text") == "text"
        if not self.is_text:
            self.code_dict = load_code_dict(
                hparams.code_dict, hparams.add_sos, hparams.add_eos
            )
            self.code_key = hparams.code_key
        self.add_sos = hparams.add_sos
        self.add_eos = hparams.add_eos
        self.collapse_code = hparams.collapse_code
        self.append_str = append_str

    def process_code(self, inp_str):
        """Convert a space-separated unit string into an id sequence."""
        toks = inp_str.split()
        if self.add_sos:
            toks = [SOS_TOK] + toks
        if self.add_eos:
            toks = toks + [EOS_TOK]
        return code_to_sequence(toks, self.code_dict, self.collapse_code)

    def process_text(self, inp_str):
        """Convert raw text into an id sequence via the English cleaners."""
        return text_to_sequence(inp_str, ["english_cleaners"])

    def get_tensor(self, inp_str):
        """Return a LongTensor of input ids for `inp_str` + `append_str`."""
        inp_str = inp_str + self.append_str
        if self.is_text:
            toks = self.process_text(inp_str)
        else:
            toks = self.process_code(inp_str)
        return torch.from_numpy(np.array(toks)).long()

    def __len__(self):
        # NOTE(review): `self.data` is never assigned anywhere in this class,
        # so len(instance) raises AttributeError — looks like dead code
        # carried over from a dataset class; confirm before relying on it.
        return len(self.data)
def load_quantized_audio_from_file(file_path):
    """Parse a file of "<utt_name>|<space-separated unit ids>" lines.

    Returns (list of utterance names, list of int unit-id lists), parallel.
    """
    base_fname_batch, quantized_units_batch = [], []
    with open(file_path) as f:
        for line in f:
            base_fname, quantized_units_str = line.rstrip().split("|")
            quantized_units = [int(q) for q in quantized_units_str.split(" ")]
            base_fname_batch.append(base_fname)
            quantized_units_batch.append(quantized_units)
    return base_fname_batch, quantized_units_batch


def synthesize_audio(model, waveglow, denoiser, inp, lab=None, strength=0.0):
    """Run Tacotron2 inference then WaveGlow vocoding for one input (batch of 1).

    Returns (mel, raw audio, denoised audio, has_eos flag). Requires CUDA.
    """
    assert inp.size(0) == 1
    inp = inp.cuda()
    if lab is not None:
        lab = torch.LongTensor(1).cuda().fill_(lab)

    with torch.no_grad():
        _, mel, _, ali, has_eos = model.inference(inp, lab, ret_has_eos=True)
        aud = waveglow.infer(mel, sigma=0.666)
        aud_dn = denoiser(aud, strength=strength).squeeze(1)
    return mel, aud, aud_dn, has_eos


def load_tacotron(tacotron_model_path, max_decoder_steps):
    """Load a Tacotron2 checkpoint; returns (fp16 CUDA model, sample_rate, hparams).

    FIX: the checkpoint is loaded to CPU first (map_location) so checkpoints
    saved on another GPU device do not fail on machines with fewer GPUs; the
    model is moved to CUDA explicitly afterwards, as before.
    """
    ckpt_dict = torch.load(tacotron_model_path, map_location="cpu")
    hparams = ckpt_dict["hparams"]
    hparams.max_decoder_steps = max_decoder_steps
    sr = hparams.sampling_rate
    model = Tacotron2(hparams)
    model.load_state_dict(ckpt_dict["model_dict"])
    model = model.cuda().eval().half()
    return model, sr, hparams


def load_waveglow(waveglow_path):
    """Load a WaveGlow checkpoint and build its Denoiser; returns (waveglow, denoiser).

    Loads to CPU first (see load_tacotron), then moves to CUDA in fp16.
    """
    waveglow = torch.load(waveglow_path, map_location="cpu")["model"]
    waveglow = waveglow.cuda().eval().half()
    for k in waveglow.convinv:
        k.float()  # keep the invertible 1x1 convs in fp32 for stability
    denoiser = Denoiser(waveglow)
    return waveglow, denoiser
folder contains code and recipes to reproduce results reported in a paper _Text-Free Prosody-Aware Generative Spoken Language Modeling_, +Eugene Kharitonov*, Ann Lee*, Adam Polyak, Yossi Adi, Jade Copet, Kushal Lakhotia, Tu-Anh Nguyen, Morgane Rivière, Abdelrahman Mohamed, Emmanuel Dupoux, Wei-Ning Hsu, 2021. arxiv/2109.03264 [[arxiv]](https://arxiv.org/abs/2109.03264). + +`*` denotes equal contribution. + +You can find demo samples [[here]](https://speechbot.github.io/pgslm/index.html). + +
+ If you find this code useful, please consider citing our work using the following BibTeX entry: +
+ + +## Additional requirements +Three packages are required in addition to fairseq, they are installable with pip: +```bash +pip install AMFM-decompy SoundFile scipy sklearn torchaudio npy-append-array +``` + +## Data preprocessing + +### Prepare unit pseudo-text transcriptions of the audio +To get unit trascripts of the speech data we rely on the preprocessing steps of [GSLM](https://github.com/pytorch/fairseq/tree/main/examples/textless_nlp/gslm/speech2unit/) work. + +Firstly, we will need to prepare manifest files for the dataset we want to preprocess +``` +mkdir manifests/ +python examples/wav2vec/wav2vec_manifest.py --valid-percent=0.0 $DATA_PATH --dest=manifests/train/ +``` +Next, we need a pre-trained HuBERT-base-ls960 model [[download]](https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt) and a corresponding kmeans-100 quantizer [[download]](https://dl.fbaipublicfiles.com/textless_nlp/gslm/hubert/km100/km.bin). Having those we can quantize the dataset: +``` +python examples/textless_nlp/gslm/speech2unit/clustering/quantize_with_kmeans.py \ + --feature_type hubert \ + --kmeans_model_path km.bin \ + --acoustic_model_path hubert_base_ls960.pt \ + --layer 6 \ + --manifest_path manifests/train/train.tsv \ + --out_quantized_file_path manifests/train/units +``` + +Finally, by running +``` +python examples/textless_nlp/pgslm/scripts/join_units_manifest.py --manifest=manifests/train/train.tsv --units=manifests/train/units --output=train.txt +``` +We will get the training data description `train.txt` in the format that pGSLM expects. The above steps have to be repeated for +dev/test sets. Importantly, we rely on an assumption that the directories are structured as in LibriSpeech, i.e. the file paths follow the +`//.wav` format. + +### Preprocess data for pGSLM +The very first step is to obtain the F0 quantization bins. +Assume the vocoder training manifest is `vocoder_train.txt` (in pGSLM data format prepared with the same process above). 
+We prepare the quantized F0 from the vocoder training data by running +```sh +bash examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh \ + vocoder_train.txt 32 # we use 32 bins in the paper +``` +- ``: sampling rate of the audio files in the manifest +- ``: where to output the output files +- ``: prefix of the output files + +The script will generate +- `.f0_stat.pt`: the speaker-level F0 statistics, which can be used in vocoder training +- `_mean_norm_log_f0_bin.th`: the quantized F0, which should be used in `prepare_data.sh` below + +**Note:** See "Pre-trained models" for the pre-computed speaker-level F0 statistics and quantized F0 bins. We suggest using the pre-computed statistics for the data preparation below in order to take advantage of the pre-trained vocoder for waveform generation. + +Next prepare the pGSLM data. +Assume train/valid/test manifests are `{train,valid,test}.txt`. +Here is an example of how to preprocess data: + +```sh +bash examples/textless_nlp/pgslm/scripts/prepare_data.sh \ + train.txt valid.txt test.txt \ + /_mean_norm_log_f0_bin.th +``` +- ``: discrete unit vocabulary size (we used a kmeans quantizer with the number of units equal to 100 in the example above) +- ``: downsampling rate relative to the waveform (e.g., 320 for HuBERT units) +- ``: sampling rate of the audio files in the manifest +- ``: where to output the preprocessed files + +This will create the dataset json config used for the next section at +`/data_config.json`. + +Note that the example script uses only one thread to compute F0, which can take +_very long_ for preprocessing large datasets. It is suggested to distribute +jobs over multiple nodes/processes with `--nshards=x` and `--rank=z` (where z is +in [1, x]) in `preprocess_f0.py`, and set `--nshards_list=x` in +`prepare_data.py` correspondingly to collect sharded F0 data. + +Now, everything is ready for training a model. 
+ +## Training Multi-Stream Transformer Unit Language Model (MS-TLM) + +Below is an example command that trains Multi-Stream Transformer Language Model (MS-TLM) on a prepared dataset: +```bash +DATASET=data_config.json + +fairseq-train $DATASET \ + --task=speech_unit_modeling \ + --arch="transformer_ulm_tiny" \ + --criterion=speech_unit_lm_criterion \ + --share-decoder-input-output-embed \ + --dropout=0.1 \ + --attention-dropout=0.1 \ + --optimizer="adam" \ + --adam-betas="(0.9, 0.98)" \ + --clip-norm=1.0 \ + --lr=0.0005 \ + --lr-scheduler="inverse_sqrt" \ + --warmup-updates=4000 \ + --warmup-init-lr=1e-07 \ + --tokens-per-sample=3072 \ + --max-tokens=3072 \ + --update-freq=4 \ + --max-epoch=70 \ + --num-workers=0 \ + --skip-invalid-size-inputs-valid-test \ + --loss-weights="1.0;0.5;0.0" \ + --ignore-f0-input \ + --checkpoint-activations \ + --fp16 \ + --max-target-positions=4096 \ + --stream-shifts="1,1" \ + --log-f0 --normalize-f0-mean --interpolate-f0 \ + --ignore-unused-valid-subsets \ + --discrete-duration --discrete-f0 +``` + +Some of the important parameters that are specific to MS-TLM: + * `arch`: specifies the Transformer architecture used. Supported options are: + * `transformer_ulm_tiny` - a tiny model that can be used for debugging; it has 2 layers, 1 attention head, FFN and embedding dimensions of 64, + * `transformer_ulm` - a base model with 6 layers, 8 heads, embedding dimension 512, and FFN dimensionality of 2048, + * `transformer_ulm_big` - the largest model we experiment with in the paper: 12-layer/16 heads, 1024/4096 embedding and FFN dimensions; + * `loss-weights`: this parameter sets importance weights (must be non-negative) for the components of the loss that correspond to unit, duration, and F0 streams. To turn off a component of the loss, its weight has to be set to 0. For instance, to predict only unit stream the parameter should be set to "1;0;0"; + * `stream-shifts`: specifies relative shifts of the two prosodic streams w.r.t. 
the unit stream (duration and F0, respectively). No shift corresponds to "0,0"; + * `ignore-duration-input`/`ignore-f0-input`: setting these flags would zero-out correpsonding input streams; + * `max-token-duration`: duration values would be max-capped by the specified value; + * `discrete-duration`/`discrete-f0`: whether duration and F0 streams should be quantized; + * `log_f0`, `normalize-f0-mean`, `normalize-f0-std`, `interpolate-f0`: configure how F0 stream is treated. `log_f0` sets up modelling in the log-space, `normalize-f0-mean`/`normalize-f0-std` control per-speaker normalization, and `interpolate-f0` enables F0 interpolation for unvoiced regions where F0 was set to 0, + * `mask-dur-prob`, `mask-f0-prob`, `mask-dur-seg-prob`, `mask-f0-seg-prob`, `mask-unit-seg-prob`, `mask-unit-seg-leng`: this family of parameters sets the probababilities of masking individual steps and spans on each stream as well as lengths of the maked spans. + + +## Pre-trained models +### MS-TLM +Below you can find checkpoints for four best-performing models from the paper (IDs 9..12 in Table 1). These models are trained on Hubert-100 transcripts of the LibriLight-6K dataset. They have the prosody streams shifted by 1 w.r.t. the unit stream. All models predict all three streams (units, duration, and F0), but two +of them only have unit steam in their input. 
+ +| | Continuous prosody | Quantized prosody | +|-------------------|--------------------|-------------------| +| No prosody input | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/continuous_no_prosody_shift_1_1.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/discrete_no_prosody_shift_1_1.pt) | +| Has prosody input | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/continuous_prosody_shift_1_1.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/ulm_checkpoints/discrete_prosody_shift_1_1.pt)| + +The optimal per-stream sampling temperatures/scaling parameters that we have identified for these models, in the (`T-token, T-duration, T-f0`) format: + +| | Continuous prosody | Quantized prosody | +|-------------------|--------------------|-------------------| +| No prosody input | 0.7, 0.125, 0.0003125| 0.7, 0.25, 0.5 | +| Has prosody input | 0.7, 0.125, 0.00125 | 0.7, 0.25, 0.7 | + +## Vocoder +| Units | Prosody | F0 stats | Checkpoint | Config | +|-------------------|---------|--------------|------------|--------| +| HuBERT-base-ls960, kmeans-100 | [[Quantized 32 bins]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_seg_bin.th) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/f0_stats.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/naive_quant_32_norm_log_seg_hubert/checkpoint.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/naive_quant_32_norm_log_seg_hubert/config.json) | +| HuBERT-base-ls960, kmeans-100 | Continuous | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/f0_stats.pt) | [[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_hubert/checkpoint.pt) | 
[[download]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_hubert/config.json) | + + +## Evaluating a trained model +Evaluation is done with the `eval/cont_metrics.py` scripts. As described in the paper, there are several metrics used. + +**Teacher-forced metrics** +```bash +SET=valid +CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt +DATA=data_config.json + +python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \ + --metric=teacher_force_everything \ + --path=$CHECKPOINT_PATH \ + --batch-size=16 \ + --fp16 \ + --seed=111 \ + --eval-subset=$SET \ + --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody +``` +(Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{'token_loss': 1.408..., 'duration_loss': 0.5424..., 'f0_loss': 0.0474...}` on LibriSpeech dev-clean). + +The parameters `--f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody` are specific for quantized-prosody models. They signal that the prosody streams must be decoded into the continuous domain before calculating correlation. It is the same `*_mean_norm_log_f0_bin.th` file as we prepared before. +The `mean_norm_log_f0_seg_bin.th` file we used with the pre-trained models can be downloaded [[here]](https://dl.fbaipublicfiles.com/textless_nlp/pgslm/vocoder/blizzard2013/mean_norm_log_f0_seg_bin.th). + + +**Consistency (aka Correlation) metrics** + +The following command estimates correlation between mean values of the F0 stream in the prompt and in the generated continuation (unit and duration steams are fixed). 
+ +```bash +T_F0=0.7 +EXPLOSION=20 +SET=test +CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt +DATA=data_config.json + +python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \ + --prefix-length=150 \ + --metric=correlation \ + --path=$CHECKPOINT_PATH \ + --batch-size=16 \ + --fp16 \ + --seed=111 \ + --teacher-force-tokens \ + --teacher-force-duration \ + --min-length=300 \ + --batch-explosion-rate=$EXPLOSION \ + --T-f0=$T_F0 \ + --eval-subset=$SET \ + --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th \ + --dequantize-prosody --n-workers=8 +``` +(Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{...'F0 corr': 0.315 ..}` on LibriSpeech test-clean). + + * By using flags `--teacher-force-tokens, --teacher-force-duration, --teacher-force-f0` one can calculate correlations along each stream while having other two streams fixed to ground-truth values (or freeze all three streams to get ground-truth correlation values); + * The parameters `T-f0`, `T-duration`, and `T-token` specify per-stream temperatures and, in the case of continuous-valued prosody, scaling parameter of the corresponding Laplace distribution (setting a temperature to 0 will enforce greedy sampling); + * `min-length` filters out sequences that are shorter then 300 duration units (i.e. 6s in the case of Hubert units); + * `prefix-length` specifies that we want to use first 150 duration units are prompt (i.e. 3s in the case of Hubert units) + + +**Correctness (aka Continuation) and Expressiveness (aka Std) metrics** + +By running the following command, we can get minMAE and Std for the log-F0 stream for the model with quantized prosody. 
+```bash +DATA=data_config.json +EXPLOSION=20 +SET=test +CHECKPOINT_PATH=discrete_prosody_shift_1_1.pt +T_F0=0.7 + +python examples/textless_nlp/pgslm/eval/cont_metrics.py $DATA \ + --prefix-length=150 \ + --metric=continuation \ + --path=$CHECKPOINT_PATH \ + --batch-size=16 \ + --fp16 \ + --seed=111 \ + --batch-explosion-rate=$EXPLOSION \ + --teacher-force-tokens \ + --teacher-force-duration \ + --T-f0=$T_F0 \ + --eval-subset=$SET \ + --f0-discretization-bounds=mean_norm_log_f0_seg_bin.th --dequantize-prosody +``` +(Using this command, our provided `discrete_prosody_shift_1_1.pt` checkpoint should produce `{...'F0 MAE': 0.0772, 'F0 Std': 0.1489...}` on LibriSpeech test-clean). + +Again, by setting `--teacher-force-tokens, --teacher-force-duration, --teacher-force-f0` we can calculate Token BLEU for the token stream (when `--teacher-force-duration` & `--teacher-force-f0` are on) and per-stream min MAE for each prosody stream individually. + +Finally, `cont_metrics.py` allows to specify the number of workers (e.g., `n-workers=8`) which allows to speed up the computation by spreading multiple worker processes +over the available GPUs. + +**Cont Word BLEU** + +We used the code and the evaluation protocol of [(Lakhotia et al., 2021)](https://arxiv.org/abs/2102.01192). + +## Sampling from a trained model + +To get (prompted or not) samples from a trained model it is enough to run `sample.py`: +```bash +CHECKPOINT_PATH=checkpoints/checkpoint_best.pt +DATASET=examples/textless_nlp/pgslm/repro/dataset/data_config.json +python examples/textless_nlp/pgslm/sample/sample.py $DATASET \ + --output=$SAMPLES \ + --path=$CHECKPOINT_PATH \ + --sampling \ + --T-token=0.7 \ + --T-duration=0.25 \ + --T-f0=0.7 \ + --max-length=500 \ + --prefix-length=150 \ + --subset=valid \ + --seed=1 \ + --match-duration \ + --code-type=hubert \ + --batch-explosion-rate=2 +``` + +Some useful parameters: + * `T-token`, `T-duration`, `T-f0` specify sampling temperature for the three streams. 
Setting a temperature to `0` switches sample to the greedy (argmax) one; + * `prefix-length`: length of the prompt, measured in timesteps (e.g. for Hubert (CPC) each timestep is 20 (10) ms); + * `subset`: which subset of the dataset to use as prompts (can be `train`, `valid`, `test`); + * `teacher-force-tokens`, `teacher-force-duration`, `teacher-force-f0`: if set, at each autoregressive step, ground-truth values replace the produced one; + * `short-curcuit`: replace sampling by ground-truth inputs; + * `match-duration`: forces the produced sample to have the same duration (in time), as the entire sequence (beyond the prompt if there is any); + * `batch-explosion-rate`: number of samples per prompt; + * `f0-discretization-bounds`: path to a file with quantization boundaries. If it is set, F0 values are de-quantized back to the continuous domain + (the model must be a quanized one); + * `max-length` sets the maximal number of segment steps to be produced. + +Note that `sample.py` automatically uses all available GPUs, to avoid that please use environment variable `CUDA_VISIBLE_DEVICES`. + +## Vocoding samples +To generate audios for output from `sample.py` (`$IN_FILE`): +```bash +python examples/textless_nlp/pgslm/generate_waveform.py \ + --in-file=$IN_FILE \ + --vocoder=$VODOER \ + --vocoder-cfg=$VOCODER_CFG \ + --results-path=$RESULTS_PATH +``` +See "Pre-trained model" for `$VOCODER` and `VOCODER_CFG`. diff --git a/fairseq/examples/textless_nlp/pgslm/data_utils.py b/fairseq/examples/textless_nlp/pgslm/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2033697b372ce615b184589a8b3c4212798f5073 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/data_utils.py @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import torch + +from tqdm import tqdm + + +class Stat: + def __init__(self, keep_raw=False): + self.x = 0.0 + self.x2 = 0.0 + self.z = 0.0 # z = logx + self.z2 = 0.0 + self.n = 0.0 + self.u = 0.0 + self.keep_raw = keep_raw + self.raw = [] + + def update(self, new_x): + new_z = new_x.log() + + self.x += new_x.sum() + self.x2 += (new_x**2).sum() + self.z += new_z.sum() + self.z2 += (new_z**2).sum() + self.n += len(new_x) + self.u += 1 + + if self.keep_raw: + self.raw.append(new_x) + + @property + def mean(self): + return self.x / self.n + + @property + def std(self): + return (self.x2 / self.n - self.mean**2) ** 0.5 + + @property + def mean_log(self): + return self.z / self.n + + @property + def std_log(self): + return (self.z2 / self.n - self.mean_log**2) ** 0.5 + + @property + def n_frms(self): + return self.n + + @property + def n_utts(self): + return self.u + + @property + def raw_data(self): + assert self.keep_raw, "does not support storing raw data!" + return torch.cat(self.raw) + + +class F0Stat(Stat): + def update(self, new_x): + # assume unvoiced frames are 0 and consider only voiced frames + if new_x is not None: + super().update(new_x[new_x != 0]) + + +def dump_speaker_f0_stat(speaker_to_f0_stat, out_prefix): + path = f"{out_prefix}.f0_stat.pt" + assert not os.path.exists(path) + + d = { + speaker: { + "f0_mean": speaker_to_f0_stat[speaker].mean, + "f0_std": speaker_to_f0_stat[speaker].std, + "logf0_mean": speaker_to_f0_stat[speaker].mean_log, + "logf0_std": speaker_to_f0_stat[speaker].std_log, + } + for speaker in speaker_to_f0_stat + } + torch.save(d, path) + + return d + + +def load_audio_path(path): + audio_paths = [] + with open(path) as f: + for line in f.readlines(): + sample = eval(line.strip()) + audio_paths.append(sample["audio"]) + + return audio_paths + + +def load_f0(f0_dir, nshards): + path_to_f0 = {} + for rank in tqdm(range(1, nshards + 1), desc=f"load f0"): + f0_shard_path = f"{f0_dir}/f0_{rank}_{nshards}.pt" + 
shard_path_to_f0 = torch.load(f0_shard_path) + path_to_f0.update(shard_path_to_f0) + return path_to_f0 diff --git a/fairseq/examples/textless_nlp/pgslm/eval/__init__.py b/fairseq/examples/textless_nlp/pgslm/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e028c26b9abf533d16fa07d00d942ca07bbba61 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/eval/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py b/fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..e98abadde3f4c22adfe6ef3a5d9bd68c4e62d408 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/eval/cont_metrics.py @@ -0,0 +1,730 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import os +import numpy as np +import scipy + +import torch +import torch.multiprocessing as mp +from fairseq import checkpoint_utils, options +from fairseq.data.codedataset import CodeDataset, ExpressiveCodeDataConfig +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from torch.utils.data import DataLoader, DistributedSampler +from fairseq.utils import move_to_cuda +from fairseq import utils +from fairseq.criterions.speech_ulm_criterion import nll_loss, mae_loss + +import time +from types import SimpleNamespace + +import sys, pathlib + +sys.path.append(str(pathlib.Path(__file__).parent.parent.resolve())) + +from naive_decoder import Naive_F0_Decoder +from inference_dataset import InferenceDataset, explode_batch +from sample.sample import do_sampling, TemperatureDecoder, FilterNamesDataset + +try: + from nltk.translate.bleu_score import sentence_bleu +except ImportError: + print("Please install nltk: `pip install --user -U nltk`") + raise + + +@torch.no_grad() +def teacher_force_everything( + args, dataset, model, criterion, tgt_dict, rank, world_size +): + prefix = args.prefix_length + + f0_decoder = None + if args.dequantize_prosody: + assert dataset.discrete_f0 + print("Reporting MAE for a discrete model") + f0_decoder = Naive_F0_Decoder( + args.f0_discretization_bounds, dataset.config.f0_vq_n_units + ).cuda() + + dataset = InferenceDataset( + dataset, + prefix=args.prefix_length, + only_prefix=False, + filter_short=True, + presort_by_length=True, + ) + sampler = ( + None + if world_size == 1 + else DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=False + ) + ) + dataloader = DataLoader( + dataset, + args.batch_size, + shuffle=False, + collate_fn=dataset.collater, + sampler=sampler, + ) + + total_token_loss, total_duration_loss, total_f0_loss, total_tokens = ( + 0.0, + 0.0, + 0.0, + 0.0, + ) + + i = 0 + for batch in dataloader: + i += 1 + batch = move_to_cuda(batch) + output = model(**batch["net_input"]) + + 
tokens, durations, f0 = output["token"], output["duration"], output["f0"] + durations, f0 = durations.squeeze(), f0.squeeze() + + token_loss = nll_loss( + tokens[:, prefix - 1 :], + batch["target"][:, prefix - 1 :].contiguous(), + batch["mask"][:, prefix - 1 :].contiguous(), + reduce=True, + ) + + if args.dequantize_prosody: + durations = durations.argmax(dim=-1) + duration_loss = mae_loss( + durations[:, prefix - 1 :].contiguous().float(), + batch["dur_target"][:, prefix - 1 :].contiguous().float(), + batch["dur_mask"][:, prefix - 1 :].contiguous(), + reduce=True, + ) + else: + duration_loss = criterion.dur_loss_fn( + durations[:, prefix - 1 :].contiguous(), + batch["dur_target"][:, prefix - 1 :].contiguous(), + batch["dur_mask"][:, prefix - 1 :].contiguous(), + reduce=True, + ) + + if f0_decoder: + f0 = f0.argmax(dim=-1) + f0 = f0_decoder(f0).squeeze(-1) + + f0_target = batch["raw_f0"] + f0_loss = mae_loss( + f0[:, prefix - 1 :].contiguous(), + f0_target[:, prefix - 1 :].contiguous(), + batch["f0_mask"][:, prefix - 1 :].contiguous(), + reduce=True, + ) + else: + f0_loss = criterion.f0_loss_fn( + f0[:, prefix - 1 :].contiguous(), + batch["f0_target"][:, prefix - 1 :].contiguous(), + batch["f0_mask"][:, prefix - 1 :].contiguous(), + reduce=True, + ) + + n_tokens = (~batch["dur_mask"])[:, prefix - 1 :].sum() + + total_token_loss += token_loss.item() + total_duration_loss += duration_loss.item() + total_f0_loss += f0_loss.item() + + total_tokens += n_tokens.item() + if args.debug and i > 5: + break + + values = torch.tensor([total_token_loss, total_duration_loss, total_f0_loss]) + normalizers = torch.tensor([total_tokens for _ in range(3)]) + + return values, normalizers + + +def get_bleu(produced_tokens, target_tokens, tgt_dict): + assert target_tokens.ndim == 1 + assert produced_tokens.size(1) == target_tokens.size(0) + + # we can have padding due to shifted channels + shift = 0 + for token in reversed(target_tokens.cpu().tolist()): + if token in [tgt_dict.pad(), 
tgt_dict.eos()]: + shift += 1 + else: + break + target_tokens = target_tokens[:-shift] + produced_tokens = produced_tokens[:, :-shift] + + string_target = tgt_dict.string(target_tokens).split() + string_candidates = [ + tgt_dict.string(produced_tokens[i, :]).split() + for i in range(produced_tokens.size(0)) + ] + + bleu3 = sentence_bleu( + references=string_candidates, + hypothesis=string_target, + weights=(1.0 / 3, 1.0 / 3, 1.0 / 3), + ) + return bleu3 + + +@torch.no_grad() +def continuation(args, dataset, model, criterion, tgt_dict, rank, world_size): + is_discrete_duration = dataset.discrete_dur + is_discrete_f0 = dataset.discrete_f0 + + f0_decoder = None + if args.dequantize_prosody: + assert dataset.discrete_f0 + print("Reporting MAE F0 for a discrete model") + f0_decoder = Naive_F0_Decoder( + args.f0_discretization_bounds, dataset.config.f0_vq_n_units + ).cuda() + + dataset = InferenceDataset( + dataset, args.prefix_length, filter_short=True, presort_by_length=True + ) + sampler = ( + None + if world_size == 1 + else DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=False + ) + ) + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collater, + sampler=sampler, + ) + + Ts = args.T_token, args.T_duration, args.T_f0 + decoder = TemperatureDecoder( + Ts, discrete_dur=is_discrete_duration, discrete_f0=is_discrete_f0 + ) + + running_stats = SimpleNamespace( + token_bleu=0.0, + duration_nll=0.0, + duration_mae=0.0, + f0_nll=0.0, + f0_mae=0.0, + n_tokens=0.0, + n_sentences=0.0, + f0_sum=0.0, + f0_sum_sq=0.0, + dur_sum=0.0, + dur_sum_sq=0.0, + ) + + for i, batch in enumerate(dataloader): + batch = explode_batch(batch, args.batch_explosion_rate) + bsz = batch["target"].size(0) + + batch = move_to_cuda(batch) + prefix = batch["prefix"][0] + + max_length_to_unroll = batch["target"].size(1) + prefix_length = batch["net_input"]["src_tokens"].size(1) + steps = max_length_to_unroll - prefix_length + 1 + + assert 
steps > 0 + produced_tokens, produced_durations, produced_f0, outputs = do_sampling( + model, + batch, + tgt_dict.eos(), + decoder, + autoregressive_steps=steps, + teacher_force_tokens=args.teacher_force_tokens, + teacher_force_duration=args.teacher_force_duration, + teacher_force_f0=args.teacher_force_f0, + ) + + if args.teacher_force_tokens: + assert (produced_tokens[:, 1:] == batch["target"]).all() + if args.teacher_force_duration: + assert (produced_durations[:, 1:] == batch["dur_target"]).all() + if args.teacher_force_f0: + assert (produced_f0[:, 1:] == batch["f0_target"]).all() + + dur_target = batch["dur_target"][:, prefix - 1 :].contiguous() + f0_target = batch["f0_target"][:, prefix - 1 :].contiguous() + + f0_mask = batch["f0_mask"][:, prefix - 1 :].contiguous() + dur_mask = batch["dur_mask"][:, prefix - 1 :].contiguous() + + duration_mae = mae_loss( + produced_durations[:, prefix:].float(), + dur_target.float(), + dur_mask, + reduce=False, + ) + min_duration_mae = duration_mae.view(bsz, -1).sum(dim=-1).min(dim=0)[0] + running_stats.duration_mae += min_duration_mae + + running_stats.dur_sum += ( + produced_durations[:, prefix:].float() * (~dur_mask) + ).sum() / args.batch_explosion_rate + running_stats.dur_sum_sq += ( + produced_durations[:, prefix:].float() * (~dur_mask) + ).pow(2.0).sum() / args.batch_explosion_rate + + if is_discrete_duration: + duration_loss = criterion.dur_loss_fn( + torch.stack([x[1] for x in outputs], dim=1), + dur_target, + dur_mask, + reduce=False, + ) + min_duration_loss = duration_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0] + running_stats.duration_nll += min_duration_loss + + if f0_decoder: # can only exist for discrete F0 models + decoded_produced_f0 = f0_decoder(produced_f0[:, prefix:]) + decoded_f0_target = batch["raw_f0"][:, prefix - 1 :].contiguous() + + if produced_f0.ndim == 3: + decoded_produced_f0 = decoded_produced_f0.squeeze(2) + decoded_f0_target = decoded_f0_target.squeeze(2) + + f0_mae = mae_loss( + 
decoded_produced_f0, decoded_f0_target, f0_mask, reduce=False + ) + f0_mae = f0_mae.view(bsz, -1).sum(dim=-1).min(dim=0)[0] + running_stats.f0_mae += f0_mae + + f0_loss = criterion.f0_loss_fn( + torch.stack([x[2] for x in outputs], dim=1), + f0_target.long(), + f0_mask, + reduce=False, + ) + f0_loss = f0_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0] + running_stats.f0_nll += f0_loss + + running_stats.f0_sum += ( + decoded_produced_f0 * (~f0_mask) + ).sum() / args.batch_explosion_rate + running_stats.f0_sum_sq += (decoded_produced_f0 * (~f0_mask)).pow( + 2.0 + ).sum() / args.batch_explosion_rate + + else: + assert not is_discrete_duration + + f0_loss = mae_loss( + produced_f0[:, prefix:], f0_target, f0_mask, reduce=False + ) + f0_loss = f0_loss.view(bsz, -1).sum(dim=-1).min(dim=0)[0] + running_stats.f0_mae += f0_loss + + running_stats.f0_sum += ( + produced_f0[:, prefix:].sum() / args.batch_explosion_rate + ) + running_stats.f0_sum_sq += ( + produced_f0[:, prefix:].pow(2.0).sum() / args.batch_explosion_rate + ) + + running_stats.n_tokens += (~dur_mask)[0, ...].sum() + + token_loss = get_bleu( + produced_tokens[:, prefix:], batch["target"][0, prefix - 1 :], tgt_dict + ) + running_stats.token_bleu += token_loss + running_stats.n_sentences += 1 + + if args.debug: + break + + values = torch.tensor( + [ + running_stats.token_bleu, + running_stats.duration_nll, + running_stats.duration_mae, + running_stats.f0_nll, + running_stats.f0_mae, + running_stats.f0_sum, + running_stats.f0_sum_sq, + running_stats.dur_sum, + running_stats.dur_sum_sq, + ] + ) + normalizers = torch.tensor( + [running_stats.n_sentences] + [running_stats.n_tokens] * 8 + ) + + return values, normalizers + + +@torch.no_grad() +def correlation(args, dataset, model, criterion, tgt_dict, rank, world_size): + is_discrete_duration = dataset.discrete_dur + is_discrete_f0 = dataset.discrete_f0 + + f0_decoder = None + if is_discrete_f0: + assert dataset.discrete_f0 + f0_decoder = Naive_F0_Decoder( + 
args.f0_discretization_bounds, dataset.config.f0_vq_n_units + ).cuda() + + if is_discrete_f0: + assert f0_decoder # correlation on tokens is meaningless + + dataset = InferenceDataset( + dataset, + args.prefix_length, + filter_short=True, + presort_by_length=True, + min_length=args.min_length, + ) + sampler = ( + None + if world_size == 1 + else DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=False + ) + ) + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=False, + collate_fn=dataset.collater, + sampler=sampler, + ) + + Ts = args.T_token, args.T_duration, args.T_f0 + decoder = TemperatureDecoder( + Ts, discrete_dur=is_discrete_duration, discrete_f0=is_discrete_f0 + ) + + mean_dur_prefix, mean_dur_cont = [], [] + mean_f0_prefix, mean_f0_cont = [], [] + + for batch in dataloader: + batch = explode_batch(batch, args.batch_explosion_rate) + batch = move_to_cuda(batch) + + assert len(batch["prefix"]) == 1 + + if args.teacher_force_tokens: + autoregressive_steps = batch["target"].size(1) - args.prefix_length - 1 + else: + autoregressive_steps = args.max_length - args.prefix_length # + max_shift? 
+ + if args.copy_target: + produced_durations, produced_f0 = batch["dur_target"], batch["f0_target"] + else: + _, produced_durations, produced_f0, outputs = do_sampling( + model, + batch, + tgt_dict.eos(), + decoder, + autoregressive_steps=autoregressive_steps, + teacher_force_tokens=args.teacher_force_tokens, + teacher_force_duration=args.teacher_force_duration, + teacher_force_f0=args.teacher_force_f0, + ) + + # first tokens actually correspond to BOS + produced_durations = produced_durations[:, 1:] + produced_f0 = produced_f0[:, 1:] + + dur_target = batch["dur_target"] + if is_discrete_duration: + produced_durations = produced_durations.float() + dur_target = dur_target.float() + + if is_discrete_f0: + produced_f0 = f0_decoder(produced_f0).squeeze(-1) + f0_target = batch["raw_f0"] + else: + f0_target = batch["f0_target"] + + # prefix values + prefix = batch["prefix"][0] + dur_prefix_mean = dur_target[:, :prefix].sum(dim=-1) / ( + (~batch["dur_mask"][:, :prefix]).sum(dim=-1) + ) + + non_voiced = f0_target[:, :prefix] == 0.0 + f0_mask = batch["f0_mask"][:, :prefix].logical_or(non_voiced) + f0_prefix_mean = f0_target[:, :prefix].sum(dim=-1) / ((~f0_mask).sum(dim=-1)) + + # continuation values + dur_cont_mean = produced_durations[:, prefix:].sum(dim=-1) / ( + (~batch["dur_mask"][:, prefix:]).sum(dim=-1) + ) + + non_voiced = produced_f0[:, prefix:] == 0.0 + f0_mask = non_voiced + f0_cont_mean = produced_f0[:, prefix:].sum(dim=-1) / ((~f0_mask).sum(dim=-1)) + + assert not f0_cont_mean.isnan().any() + + mean_dur_prefix.append(dur_prefix_mean.cpu()) + mean_dur_cont.append(dur_cont_mean.cpu()) + + mean_f0_prefix.append(f0_prefix_mean.cpu()) + mean_f0_cont.append(f0_cont_mean.cpu()) + + if args.debug and len(mean_dur_prefix) > 10: + break + + mean_dur_prefix, mean_dur_cont = torch.cat(mean_dur_prefix), torch.cat( + mean_dur_cont + ) + mean_f0_prefix, mean_f0_cont = torch.cat(mean_f0_prefix), torch.cat(mean_f0_cont) + + return mean_dur_prefix, mean_dur_cont, 
mean_f0_prefix, mean_f0_cont + + +def main(rank, world_size, args): + start = time.time() + + if world_size > 1: + torch.distributed.init_process_group( + backend="gloo", init_method="env://", world_size=world_size, rank=rank + ) + torch.cuda.set_device(rank % torch.cuda.device_count()) + + raw_args = args + + args = convert_namespace_to_omegaconf(args) + if args.common.seed is not None: + np.random.seed(args.common.seed) + utils.set_torch_seed(args.common.seed) + + models, model_args, task = checkpoint_utils.load_model_ensemble_and_task( + [raw_args.path], arg_overrides={"data": args.task.data} + ) + + tgt_dict = task.target_dictionary + + for model in models: + model.prepare_for_inference_(args) + model.cuda().eval() + if raw_args.fp16: + model = model.half() + model = models[0] + + config = ExpressiveCodeDataConfig(args.task.data) + + dataset = CodeDataset( + manifest=config.manifests[raw_args.eval_subset], + dictionary=task.source_dictionary, + dur_dictionary=task.source_duration_dictionary, + f0_dictionary=task.source_f0_dictionary, + config=config, + discrete_dur=task.cfg.discrete_duration, + discrete_f0=task.cfg.discrete_f0, + log_f0=task.cfg.log_f0, + normalize_f0_mean=task.cfg.normalize_f0_mean, + normalize_f0_std=task.cfg.normalize_f0_std, + interpolate_f0=task.cfg.interpolate_f0, + shifts=task.cfg.stream_shifts, + return_filename=True, + strip_filename=False, + return_continuous_f0=raw_args.dequantize_prosody, + ) + + if raw_args.filter_names: + dataset = FilterNamesDataset(dataset, raw_args.filter_names) + + criterion = task.build_criterion(model_args.criterion) + + name2metric = { + "continuation": continuation, + "teacher_force_everything": teacher_force_everything, + "correlation": correlation, + } + + name2keys = { + "continuation": ( + "Token BLEU3", + "Duration NLL", + "Duration MAE", + "F0 NLL", + "F0 MAE", + "F0 sum", + "F0 sum_sq", + "Dur sum", + "Dur sum_sq", + ), + "teacher_force_everything": ("token_loss", "duration_loss", "f0_loss"), + 
"correlation": ("Duration corr", "F0 corr"), + } + metric_name = raw_args.metric + + metric = name2metric[metric_name] + results = metric(raw_args, dataset, model, criterion, tgt_dict, rank, world_size) + + values = None + + if metric_name not in [ + "correlation", + ]: + values, normalizers = results + values = maybe_aggregate_normalize(values, normalizers, world_size) + elif metric_name == "correlation": + values = maybe_aggregate_correlations(results, world_size) + else: + assert False + + assert values is not None + summary = dict(zip(name2keys[raw_args.metric], values.tolist())) + if metric_name == "continuation": + summary["F0 Std"] = np.sqrt(-summary["F0 sum"] ** 2 + summary["F0 sum_sq"]) + summary["Dur Std"] = np.sqrt(-summary["Dur sum"] ** 2 + summary["Dur sum_sq"]) + del summary["F0 sum"] + del summary["F0 sum_sq"] + del summary["Dur sum"] + del summary["Dur sum_sq"] + + summary["metric"] = metric_name + + if rank == 0: + print(summary) + if raw_args.wandb: + wandb_results(summary, raw_args) + print("# finished in ", time.time() - start, "seconds") + + +def wandb_results(summary, raw_args): + import wandb + + run = wandb.init( + project=raw_args.wandb_project_name, tags=raw_args.wandb_tags.split(",") + ) + run.config.metric = raw_args.metric + run.config.model = raw_args.path + run.config.data = raw_args.data + + if raw_args.wandb_run_name: + run.name = raw_args.wandb_run_name + run.save() + + wandb.log(summary) + wandb.finish() + + +def maybe_aggregate_normalize(values, normalizers, world_size): + if world_size > 1: + torch.distributed.barrier() + + torch.distributed.all_reduce_multigpu([values]) + torch.distributed.all_reduce_multigpu([normalizers]) + + return values / normalizers + + +def maybe_aggregate_correlations(results, world_size): + if world_size > 1: + output = [None for _ in range(world_size)] + torch.distributed.all_gather_object(output, results) + mean_dur_prefix, mean_dur_cont, mean_f0_prefix, mean_f0_cont = [ + torch.cat([x[i] for x in 
output]) for i in range(4) + ] + else: + mean_dur_prefix, mean_dur_cont, mean_f0_prefix, mean_f0_cont = results + + corr_dur = scipy.stats.pearsonr(mean_dur_prefix.numpy(), mean_dur_cont.numpy())[0] + corr_f0 = scipy.stats.pearsonr(mean_f0_prefix.numpy(), mean_f0_cont.numpy())[0] + values = torch.tensor([corr_dur, corr_f0]) + + return values + + +def cli_main(): + parser = options.get_interactive_generation_parser() + parser.add_argument( + "--prefix-length", + type=int, + default=1, + help="Prompt prefix length (including )", + ) + parser.add_argument( + "--duration-scale", + type=float, + default=1, + help="Multiply durations by the given scaler", + ) + parser.add_argument( + "--debug", action="store_true", help="Process only the first batch" + ) + parser.add_argument("--n_hypotheses", type=int, default=1) + parser.add_argument("--filter-names", type=str, default=None) + parser.add_argument( + "--max-length", type=int, default=200, help="Maximal produced length" + ) + + parser.add_argument("--teacher-force-tokens", action="store_true", default=False) + parser.add_argument("--teacher-force-duration", action="store_true", default=False) + parser.add_argument("--teacher-force-f0", action="store_true", default=False) + + parser.add_argument("--copy-target", action="store_true", default=False) + parser.add_argument("--min-length", type=int, default=None) + parser.add_argument("--f0-discretization-bounds", type=str, default=None) + parser.add_argument("--dequantize-prosody", action="store_true") + parser.add_argument("--batch-explosion-rate", type=int, default=1) + + parser.add_argument( + "--metric", + choices=["continuation", "teacher_force_everything", "correlation"], + required=True, + ) + + parser.add_argument("--wandb", action="store_true") + parser.add_argument("--wandb-project-name", type=str, default="eslm") + parser.add_argument("--wandb-tags", type=str, default="") + parser.add_argument("--wandb-run-name", type=str, default="") + + 
parser.add_argument("--T-token", type=float, default=1.0) + parser.add_argument("--T-duration", type=float, default=1.0) + parser.add_argument("--T-f0", type=float, default=1.0) + + parser.add_argument("--n-workers", type=int, default=1) + + parser.add_argument( + "--eval-subset", type=str, default="valid", choices=["valid", "test"] + ) + + args = options.parse_args_and_arch(parser) + + assert ( + args.prefix_length >= 1 + ), "Prefix length includes bos token , hence the minimum is 1." + assert args.temperature >= 0.0, "T must be non-negative!" + + if args.dequantize_prosody: + assert args.f0_discretization_bounds + + world_size = args.n_workers or torch.cuda.device_count() + if world_size > 1: + import random + + mp.set_start_method("spawn", force=True) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(random.randint(10_000, 50_000)) + + mp.spawn( + main, + nprocs=world_size, + args=( + world_size, + args, + ), + join=True, + ) + else: + main(rank=0, world_size=world_size, args=args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/textless_nlp/pgslm/generate_waveform.py b/fairseq/examples/textless_nlp/pgslm/generate_waveform.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f348bb9b3e3ee998a5f8e2dcafe01a45476c98 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/generate_waveform.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import ast +import argparse +import json +import logging +from pathlib import Path +import soundfile as sf +import torch + +from tqdm import tqdm + +from fairseq import utils +from fairseq.models.text_to_speech.vocoder import CodeHiFiGANVocoder + + +logging.basicConfig() +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def dump_result(args, data, sample_id, pred_wav): + assert "audio" in data or args.results_path is not None + if args.results_path: + fname = Path(data["audio"]).name if "audio" in data else f"{sample_id}_pred.wav" + out_file = Path(args.results_path) / fname + + sf.write( + out_file.as_posix(), + pred_wav.detach().cpu().numpy(), + args.sample_rate, + ) + + +def load_data(in_file): + with open(in_file) as f: + data = [ast.literal_eval(line.strip()) for line in f] + + return data + + +def get_f0_upsample_ratio(code_hop_size, f_hop_size): + ratio = (code_hop_size // 160) // (f_hop_size // 256) * 2 + return ratio + + +def main(args): + logger.info(args) + + use_cuda = torch.cuda.is_available() and not args.cpu + + with open(args.vocoder_cfg) as f: + vocoder_cfg = json.load(f) + vocoder = CodeHiFiGANVocoder(args.vocoder, vocoder_cfg) + if use_cuda: + vocoder = vocoder.cuda() + + data = load_data(args.in_file) + + if args.results_path: + Path(args.results_path).mkdir(exist_ok=True, parents=True) + + for i, d in tqdm(enumerate(data), total=len(data)): + code_key = "cpc_km100" if "cpc_km100" in d else "hubert" + code = list(map(int, d[code_key].split())) + + x = { + "code": torch.LongTensor(code).view(1, -1), + "f0": torch.Tensor(d["f0"]).view(1, -1), + } + + f0_up_ratio = get_f0_upsample_ratio( + vocoder_cfg["code_hop_size"], vocoder_cfg["hop_size"] + ) + if f0_up_ratio > 1: + bsz, cond_length = x["f0"].size() + x["f0"] = x["f0"].unsqueeze(2).repeat(1, 1, f0_up_ratio).view(bsz, -1) + + x = utils.move_to_cuda(x) if use_cuda else x + wav = vocoder(x) + dump_result(args, d, i, wav) + + 
+def cli_main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-file", + type=str, + required=True, + help="Input file following the same format of the output from sample.py ('f0' and 'cpc_km100/hubert' are required fields)", + ) + parser.add_argument( + "--vocoder", type=str, required=True, help="path to the vocoder" + ) + parser.add_argument( + "--vocoder-cfg", + type=str, + required=True, + help="path to the vocoder config", + ) + parser.add_argument("--sample-rate", type=int, default=16_000) + parser.add_argument( + "--results-path", + type=str, + default=None, + help="Output directory. If not set, the audios will be stored following the 'audio' field specified in the input file.", + ) + parser.add_argument("--cpu", action="store_true", help="run on CPU") + + args = parser.parse_args() + + main(args) + + +if __name__ == "__main__": + cli_main() diff --git a/fairseq/examples/textless_nlp/pgslm/inference_dataset.py b/fairseq/examples/textless_nlp/pgslm/inference_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7cfa5f5437f17a665070f0c539c11a92e30855 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/inference_dataset.py @@ -0,0 +1,103 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch + + +class InferenceDataset: + def __init__( + self, + dataset, + prefix, + only_prefix=True, + presort_by_length=True, + filter_short=False, + min_length=None, + ): + self.dataset = dataset + self.collater = self.dataset.collater + self.prefix = prefix + self.only_prefix = only_prefix + self.filter_short = filter_short + + self.remapping = list(range(len(self.dataset))) + if min_length: + assert min_length >= prefix + 1 + + length_thr = prefix + 1 if not min_length else min_length + + if filter_short: + self.remapping = list( + filter( + lambda i: self.dataset[i]["dur_source"].sum() > length_thr, + self.remapping, + ) + ) + print( + f"# the initial dataset of {len(self.dataset)} examples became {len(self.remapping)} after filtering" + f" examples shorter than {length_thr} (in duration units)" + ) + + if presort_by_length: + lengths = {index: dataset.size(index) for index in self.remapping} + self.remapping.sort(key=lambda i: lengths[i]) + + @property + def pads(self): + return self.dataset.pads + + def __len__(self): + return len(self.remapping) + + def original_size(self, k): + k = self.remapping[k] + return self.dataset.size(k) + + def __getitem__(self, k): + k = self.remapping[k] + channels = self.dataset[k] + + if self.prefix and self.only_prefix: + dur_channel = channels["dur_source"] + assert dur_channel.sum() >= self.prefix + + token_times = dur_channel.cumsum(dim=-1) + cut_after = torch.searchsorted(token_times, torch.tensor(self.prefix)) + + r = {} + for channel_name, value in channels.items(): + if isinstance(value, torch.Tensor) and "source" in channel_name: + # if self.filter_short: assert value.size(0) >= self.prefix + r[channel_name] = value[: cut_after + 1] + else: + r[channel_name] = value + + r["prefix"] = cut_after + 1 + else: + r = channels + + return r + + +def explode_batch(batch, times): + if times == 1: + return batch + + new_batch = {} + + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + assert 
value.size(0) == 1 + new_batch[key] = torch.cat([value] * times) + elif key in ["ntokens", "nsentences"]: + new_batch[key] = value * times + elif key in ["prefix", "filename"]: + new_batch[key] = value + elif key == "net_input": + new_batch[key] = explode_batch(value, times) + else: + assert False, key + return new_batch diff --git a/fairseq/examples/textless_nlp/pgslm/naive_decoder.py b/fairseq/examples/textless_nlp/pgslm/naive_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5132889792b7b2d073037a6b7f3003927d6a802a --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/naive_decoder.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import warnings + + +class Naive_F0_Decoder(torch.nn.Module): + def __init__(self, bounds_path, n_units=32): + super().__init__() + + bounds = torch.load(bounds_path) + bounds = torch.from_numpy(bounds[n_units]) + assert bounds.ndim == 1 + + pad = torch.tensor([-5.0, -5.0]) # bos, eos, pad are in the dictionary + centers = torch.cat( + [bounds[0:1], 0.5 * (bounds[1:] + bounds[:-1]), bounds[-1:], pad[:]] + ) + + self.embedding = torch.nn.Embedding.from_pretrained( + centers.unsqueeze(-1), freeze=True + ) + self.max_n = self.embedding.weight.numel() + + def forward(self, discrete_f0: torch.Tensor): + in_bounds = (0 <= discrete_f0).all() and (discrete_f0 < self.max_n).all() + if not in_bounds: + warnings.warn( + f"F0 contains some weird outputs: discrete_f0.max().item()={discrete_f0.max().item()} discrete_f0.min().item()={discrete_f0.min().item()}; " + f"while we have embeddings for {self.max_n} values. " + "Assuming this is a no-prosody model -- but be careful!" 
+ ) + + mask = discrete_f0 >= self.max_n + discrete_f0 = discrete_f0.masked_fill(mask, self.max_n - 1) + + return self.embedding(discrete_f0).squeeze(-1) diff --git a/fairseq/examples/textless_nlp/pgslm/prepare_dataset.py b/fairseq/examples/textless_nlp/pgslm/prepare_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3d5edaa58fc701f614ee8494ad1f3cb2ad56bda6 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/prepare_dataset.py @@ -0,0 +1,143 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from multiprocessing import Pool + +import os +from collections import defaultdict +from itertools import starmap + +import torch +from npy_append_array import NpyAppendArray +from tqdm import tqdm + +from data_utils import dump_speaker_f0_stat, F0Stat, load_f0 +from fairseq.data.codedataset import ( + ExpressiveCodeDataConfig, + parse_manifest, + F0_FRAME_SPACE, + align_f0_to_durations, +) +from fairseq.tasks.speech_ulm_task import UnitDictionary + + +def load_meta(meta_path, split): + config = ExpressiveCodeDataConfig(meta_path) + manifest_path = config.manifests[split] + dictionary = UnitDictionary(n_units=config.n_units) + audio_paths, codes, durs, speakers = parse_manifest(manifest_path, dictionary) + return config, audio_paths, codes, durs, speakers + + +def _align_f0(f0, dur, ratio, frm_tol=5): + if f0 is None: + seg_f0 = torch.zeros_like(dur, dtype=torch.float) + else: + seg_f0 = align_f0_to_durations(f0, dur, ratio, tol=frm_tol * ratio) + return seg_f0.numpy() # try a hacky stuff + + +def align_f0(path_to_f0, audio_paths, durs, ratio, mp=False): + chunk_size = 2000 + num_procs = 40 + iterable = ((path_to_f0[p], d, ratio) for p, d in zip(audio_paths, durs)) + + seg_f0s = [] + if mp: + with Pool(num_procs) as pool: + iterator = tqdm( + pool.istarmap(_align_f0, iterable, chunk_size), + desc="align f0", + 
total=len(durs), + ) + for seg_f0 in iterator: + seg_f0s.append(torch.from_numpy(seg_f0).float()) + else: + iterator = tqdm(starmap(_align_f0, iterable), desc="align f0", total=len(durs)) + for seg_f0 in iterator: + seg_f0s.append(torch.from_numpy(seg_f0).float()) + + return seg_f0s + + +def prepare_seg_data(config, audio_paths, codes, durs, speakers, path_to_f0): + ratio = config.code_hop_size / (config.sampling_rate * F0_FRAME_SPACE) + seg_f0s = align_f0(path_to_f0, audio_paths, durs, ratio) + data = { + "codes": codes, + "duration": durs, + "f0": seg_f0s, + "speaker": speakers, + "path": audio_paths, + } + return data + + +def dump_seg_data(data, out_prefix): + key_targs = { + "codes": f"{out_prefix}.code.npy", + "duration": f"{out_prefix}.dur.npy", + "f0": f"{out_prefix}.f0.npy", + } + for key, targ in key_targs.items(): + assert not os.path.exists(targ) + npaa = NpyAppendArray(targ) + for utt_data in tqdm(data[key], desc=f"dumping {key}"): + npaa.append(utt_data.numpy()) + + assert not os.path.exists(f"{out_prefix}.path.txt") + with open(f"{out_prefix}.path.txt", "w") as f: + for x in data["path"]: + f.write(f"{str(x)}\n") + + assert not os.path.exists(f"{out_prefix}.leng.txt") + with open(f"{out_prefix}.leng.txt", "w") as f: + for x in data["codes"]: + f.write(f"{len(x)}\n") + + assert not os.path.exists(f"{out_prefix}.speaker.txt") + with open(f"{out_prefix}.speaker.txt", "w") as f: + for x in data["speaker"]: + f.write(f"{str(x)}\n") + + print(f"wrote to files with prefix {out_prefix}") + + +def main(meta_path, f0_dir, splits, nshards_list): + speaker_to_stat = defaultdict(F0Stat) + if len(nshards_list) == 1: + nshards_list = nshards_list * len(splits) + else: + assert len(nshards_list) == len(splits) + + for split, nshards in zip(splits, nshards_list): + config, audio_paths, codes, durs, speakers = load_meta(meta_path, split) + path_to_f0 = load_f0(f"{f0_dir}/{split}", nshards) + + # segment-level data + data = prepare_seg_data(config, audio_paths, codes, 
durs, speakers, path_to_f0) + dump_seg_data(data, config.manifests[split]) + + # speaker f0 + for audio_path, speaker in tqdm(zip(audio_paths, speakers)): + f0 = path_to_f0[audio_path] + speaker_to_stat[speaker].update(f0) + dump_speaker_f0_stat(speaker_to_stat, config.manifests[split]) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("meta_path") + parser.add_argument("f0_dir", help="out_dir from preprocess_f0") + parser.add_argument("--splits", nargs="+", default=["train", "valid"]) + parser.add_argument( + "--nshards_list", type=int, nargs="+", default=[20], help="number of f0 shards" + ) + args = parser.parse_args() + print(args) + + main(**vars(args)) diff --git a/fairseq/examples/textless_nlp/pgslm/preprocess_f0.py b/fairseq/examples/textless_nlp/pgslm/preprocess_f0.py new file mode 100644 index 0000000000000000000000000000000000000000..afe899cb85f2fafe78e86cacba17f1d376c64394 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/preprocess_f0.py @@ -0,0 +1,65 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +from tqdm import tqdm +from data_utils import load_audio_path +from fairseq.data.codedataset import get_f0_by_filename + + +def process_one(path, sr): + """ + Args: + path: audio file path + sr: sampling rate + """ + try: + # YAAPT throws errors in some rare cases + f0 = get_f0_by_filename(path, sr) + except Exception as e: + print( + f"WARNING: error when processing {path}. set f0 to zero. 
original error message:\n{e}" + ) + f0 = None + return f0 + + +def main(file_path, out_dir, nshards, rank, sampling_rate): + # load data + audio_paths = load_audio_path(file_path) + + # shard + assert nshards <= len(audio_paths) and nshards > 0 + shard_size = len(audio_paths) / nshards + s = int(round((rank - 1) * shard_size)) + e = int(round(rank * shard_size)) + audio_paths = audio_paths[s:e] + + # process + path_to_f0 = {} + for i, audio_path in enumerate(tqdm(audio_paths)): + f0 = process_one(audio_path, sampling_rate) + path_to_f0[audio_path] = f0 + print(f"finished processing {len(path_to_f0)} utterances ({s}-{e})") + + f0_path = f"{out_dir}/f0_{rank}_{nshards}.pt" + os.makedirs(out_dir, exist_ok=True) + torch.save(path_to_f0, f0_path) + print(f"saved to {f0_path}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("file_path") + parser.add_argument("out_dir") + parser.add_argument("--nshards", type=int, default=20) + parser.add_argument("--rank", type=int, default=1) + parser.add_argument("--sampling_rate", type=int, default=16000) + args = parser.parse_args() + + main(**vars(args)) diff --git a/fairseq/examples/textless_nlp/pgslm/quantize_f0.py b/fairseq/examples/textless_nlp/pgslm/quantize_f0.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e3df2fe25f768a80f6db2d05d6b8eef436938d --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/quantize_f0.py @@ -0,0 +1,94 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +from collections import defaultdict +from functools import partial + +import numpy as np +import torch +from tqdm import tqdm + +from data_utils import dump_speaker_f0_stat, F0Stat, load_audio_path, load_f0 + + +def load_speaker(path): + speakers = [] + with open(path) as f: + for line in f.readlines(): + sample = eval(line.strip()) + assert "speaker" in sample + speakers.append(sample["speaker"]) + return speakers + + +def quantize_f0(speaker_to_f0, f0_stats, nbins, normalize, log): + f0_all = [] + for speaker, f0 in speaker_to_f0.items(): + f0 = f0.raw_data + if log: + f0 = f0.log() + mean = f0_stats[speaker]["logf0_mean"] if log else f0_stats[speaker]["f0_mean"] + std = f0_stats[speaker]["logf0_std"] if log else f0_stats[speaker]["f0_std"] + if normalize == "mean": + f0 = f0 - mean + elif normalize == "meanstd": + f0 = (f0 - mean) / std + f0_all.extend(f0.tolist()) + + hist, bin_x = np.histogram(f0_all, 100000) + cum_hist = np.cumsum(hist) / len(f0_all) * 100 + + f0_bin = {} + for num_bin in nbins: + bin_offset = [] + bin_size = 100 / num_bin + threshold = bin_size + for i in range(num_bin - 1): + index = (np.abs(cum_hist - threshold)).argmin() + bin_offset.append(bin_x[index]) + threshold += bin_size + f0_bin[num_bin] = np.array(bin_offset) + + return f0_bin + + +def main(file_path, f0_dir, out_dir, out_prefix, nbins, nshards, normalize, log): + audio_paths = load_audio_path(file_path) + path_to_f0 = load_f0(f0_dir, nshards) + + speakers = load_speaker(file_path) + speaker_to_f0 = defaultdict(partial(F0Stat, True)) + + # speaker f0 stats + for audio_path, speaker in tqdm(zip(audio_paths, speakers)): + f0 = path_to_f0[audio_path] + speaker_to_f0[speaker].update(f0) + f0_stats = dump_speaker_f0_stat(speaker_to_f0, f"{out_dir}/{out_prefix}") + + # quantize + f0_bin = quantize_f0(speaker_to_f0, f0_stats, nbins, normalize, log) + log_suffix = "_log" if log else "" + f0_bin_out_file = f"{out_dir}/{out_prefix}_{normalize}_norm{log_suffix}_f0_bin.th" + 
torch.save(f0_bin, f0_bin_out_file) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("file_path") + parser.add_argument("f0_dir", help="out_dir from preprocess_f0") + parser.add_argument("out_dir") + parser.add_argument("out_prefix") + parser.add_argument("--nbins", nargs="+", type=int, default=[32]) + parser.add_argument("--nshards", type=int, default=20, help="number of f0 shards") + parser.add_argument( + "--normalize", type=str, choices=["meanstd", "mean", "none"], default="mean" + ) + parser.add_argument("--log", action="store_true") + args = parser.parse_args() + print(args) + + main(**vars(args)) diff --git a/fairseq/examples/textless_nlp/pgslm/sample/__init__.py b/fairseq/examples/textless_nlp/pgslm/sample/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0e028c26b9abf533d16fa07d00d942ca07bbba61 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/sample/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/fairseq/examples/textless_nlp/pgslm/scripts/prepare_data.sh b/fairseq/examples/textless_nlp/pgslm/scripts/prepare_data.sh new file mode 100644 index 0000000000000000000000000000000000000000..ec892e59a4c6aa416e7871e89f443cd046294980 --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/scripts/prepare_data.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+
+set -eu
+
+train_json=$1
+valid_json=$2
+test_json=$3
+n_units=$4
+hop_size=$5
+sr=$6
+f0_quantizer=$7
+out_dir=$8
+
+meta_path="$out_dir/data_config.json"
+f0_dir="$out_dir/f0"
+
+mkdir -p $out_dir
+ln -sf $train_json $out_dir/train.txt
+ln -sf $valid_json $out_dir/valid.txt
+ln -sf $test_json $out_dir/test.txt
+
+cat <<EOF >$meta_path
+{
+  "manifests": {
+    "train": "$out_dir/train.txt",
+    "valid": "$out_dir/valid.txt",
+    "test": "$out_dir/test.txt"
+  },
+  "n_units": $n_units,
+  "code_hop_size": $hop_size,
+  "sampling_rate": $sr,
+  "multispkr": "parent_parent_name",
+
+  "f0_vq_type": "naive",
+  "f0_vq_naive_quantizer": {
+    "log_mean_norm": "$f0_quantizer"
+  },
+  "f0_vq_n_units": 32
+}
+EOF
+
+for split in train valid test; do
+  python examples/textless_nlp/pgslm/preprocess_f0.py \
+    $out_dir/$split.txt $f0_dir/$split --nshards=1 --rank=1 --sampling_rate=$sr
+
+  #NSHARDS=16
+  #seq 1 $NSHARDS | parallel -j $NSHARDS python examples/textless_nlp/pgslm/preprocess_f0.py \
+  #  $out_dir/$split.txt $f0_dir/$split --nshards=$NSHARDS --sampling_rate=$sr --rank
+done
+
+# Please make sure that the number of shards (--nshards_list) is consistent across commands
+python examples/textless_nlp/pgslm/prepare_dataset.py \
+  $meta_path $f0_dir --splits test valid train --nshards_list 1
diff --git a/fairseq/examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh b/fairseq/examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3a285a39bc2bea3240b0f0e38ee71fa872fcadf5
--- /dev/null
+++ b/fairseq/examples/textless_nlp/pgslm/scripts/prepare_f0_quantization.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +set -eu + +train_json=$1 +sr=$2 +nbins=$3 +out_dir=$4 +out_prefix=$5 + +f0_dir="$out_dir/f0" + +python examples/textless_nlp/pgslm/preprocess_f0.py \ + $train_json $f0_dir/${out_prefix}_f0_quant --nshards 1 --rank 1 --sampling_rate $sr + +# NB: one can use parallel here: +# NSHARDS=16 +# +#seq 1 $NSHARDS | parallel -j $NSHARDS python examples/textless_nlp/pgslm/preprocess_f0.py \ +# $train_json $f0_dir/${out_prefix}_f0_quant --nshards $NSHARDS --sampling_rate $sr --rank + +python examples/textless_nlp/pgslm/quantize_f0.py \ + $train_json $f0_dir/${out_prefix}_f0_quant $out_dir $out_prefix --nbins $nbins --nshards 1 --normalize mean --log diff --git a/fairseq/examples/textless_nlp/pgslm/truncated_laplace.py b/fairseq/examples/textless_nlp/pgslm/truncated_laplace.py new file mode 100644 index 0000000000000000000000000000000000000000..089f8a8cfce3cbc56ca5d3a7c9bc368184ba780d --- /dev/null +++ b/fairseq/examples/textless_nlp/pgslm/truncated_laplace.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import warnings + + +def truncated_laplace(mean, T, truncate_by_zero=False): + """Generating a sample from a Laplace distribution, possible left-truncated at zero. + A bit of explanation here https://stats.stackexchange.com/a/357598 . 
+    """
+    assert isinstance(mean, torch.Tensor)
+
+    if not truncate_by_zero:
+        percentile = 0.0
+    else:
+        if not (mean >= 0.0).all():
+            warnings.warn(f"means are supposed to be non-negative, but got {mean}")
+            mean = torch.clamp_min(mean, 0.0)
+
+        lower_bound = mean.new_tensor([0.0])
+        percentile = 0.5 + 0.5 * torch.sign(lower_bound - mean) * (
+            1.0 - torch.exp(-1.0 / T * torch.abs(mean - lower_bound))
+        )
+
+    p = torch.empty_like(mean).uniform_() * (1.0 - percentile) + percentile
+    return mean - T * torch.sign(p - 0.5) * torch.log(1 - 2 * torch.abs(p - 0.5))
diff --git a/fairseq/examples/textless_nlp/speech-resynth/README.md b/fairseq/examples/textless_nlp/speech-resynth/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a099682cdbdb991e2c37f1a00387e0b8fc8f121c
--- /dev/null
+++ b/fairseq/examples/textless_nlp/speech-resynth/README.md
@@ -0,0 +1,28 @@
+
+# Speech Resynthesis from Discrete Disentangled Self-Supervised Representations
+Landing page with useful resources for the [Speech Resynthesis from Discrete Disentangled Self-Supervised Representations](https://arxiv.org/abs/2104.00355) paper.
+
+

+ +__Abstract__: We propose using self-supervised discrete representations for the task of speech resynthesis. To generate disentangled representation, we separately extract low-bitrate representations for speech content, prosodic information, and speaker identity. This allows to synthesize speech in a controllable manner. We analyze various state-of-the-art, self-supervised representation learning methods and shed light on the advantages of each method while considering reconstruction quality and disentanglement properties. Specifically, we evaluate the F0 reconstruction, speaker identification performance (for both resynthesis and voice conversion), recordings' intelligibility, and overall quality using subjective human evaluation. Lastly, we demonstrate how these representations can be used for an ultra-lightweight speech codec. Using the obtained representations, we can get to a rate of 365 bits per second while providing better speech quality than the baseline methods. + + +## Quick Links +- [Paper](https://arxiv.org/pdf/2104.00355.pdf) +- [Samples](https://speechbot.github.io/resynthesis/index.html) +- [Code](https://github.com/facebookresearch/speech-resynthesis) + +The codebase for the [Speech Resynthesis from Discrete Disentangled Self-Supervised Representations](https://arxiv.org/abs/2104.00355) paper can be found under the following [repository](https://github.com/facebookresearch/speech-resynthesis). + + +## Citation +``` +@inproceedings{polyak21_interspeech, + author={Adam Polyak and Yossi Adi and Jade Copet and + Eugene Kharitonov and Kushal Lakhotia and + Wei-Ning Hsu and Abdelrahman Mohamed and Emmanuel Dupoux}, + title={{Speech Resynthesis from Discrete Disentangled Self-Supervised Representations}}, + year=2021, + booktitle={Proc. 
Interspeech 2021}, +} +``` diff --git a/fairseq/examples/translation/README.md b/fairseq/examples/translation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2941f5eb8482dab61dca5eca27a71abd7ee5bf5c --- /dev/null +++ b/fairseq/examples/translation/README.md @@ -0,0 +1,301 @@ +# Neural Machine Translation + +This README contains instructions for [using pretrained translation models](#example-usage-torchhub) +as well as [training new models](#training-a-new-model). + +## Pre-trained models + +Model | Description | Dataset | Download +---|---|---|--- +`conv.wmt14.en-fr` | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2)
newstest2012/2013:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.ntst1213.tar.bz2) +`conv.wmt14.en-de` | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT14 English-German](http://statmt.org/wmt14/translation-task.html#Download) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-de.fconv-py.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-de.newstest2014.tar.bz2) +`conv.wmt17.en-de` | Convolutional
([Gehring et al., 2017](https://arxiv.org/abs/1705.03122)) | [WMT17 English-German](http://statmt.org/wmt17/translation-task.html#Download) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt17.v2.en-de.fconv-py.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.v2.en-de.newstest2014.tar.bz2) +`transformer.wmt14.en-fr` | Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2) +`transformer.wmt16.en-de` | Transformer
([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2)
newstest2014:
[download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2) +`transformer.wmt18.en-de` | Transformer
([Edunov et al., 2018](https://arxiv.org/abs/1808.09381))
WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz)
See NOTE in the archive +`transformer.wmt19.en-de` | Transformer
([Ng et al., 2019](https://arxiv.org/abs/1907.06616))
WMT'19 winner | [WMT'19 English-German](http://www.statmt.org/wmt19/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-de.joined-dict.ensemble.tar.gz) +`transformer.wmt19.de-en` | Transformer
([Ng et al., 2019](https://arxiv.org/abs/1907.06616))
WMT'19 winner | [WMT'19 German-English](http://www.statmt.org/wmt19/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.de-en.joined-dict.ensemble.tar.gz) +`transformer.wmt19.en-ru` | Transformer
([Ng et al., 2019](https://arxiv.org/abs/1907.06616))
WMT'19 winner | [WMT'19 English-Russian](http://www.statmt.org/wmt19/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.en-ru.ensemble.tar.gz) +`transformer.wmt19.ru-en` | Transformer
([Ng et al., 2019](https://arxiv.org/abs/1907.06616))
WMT'19 winner | [WMT'19 Russian-English](http://www.statmt.org/wmt19/translation-task.html) | model:
[download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt19.ru-en.ensemble.tar.gz) + +## Example usage (torch.hub) + +We require a few additional Python dependencies for preprocessing: +```bash +pip install fastBPE sacremoses subword_nmt +``` + +Interactive translation via PyTorch Hub: +```python +import torch + +# List available models +torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ] + +# Load a transformer trained on WMT'16 En-De +# Note: WMT'19 models use fastBPE instead of subword_nmt, see instructions below +en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', + tokenizer='moses', bpe='subword_nmt') +en2de.eval() # disable dropout + +# The underlying model is available under the *models* attribute +assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel) + +# Move model to GPU for faster translation +en2de.cuda() + +# Translate a sentence +en2de.translate('Hello world!') +# 'Hallo Welt!' + +# Batched translation +en2de.translate(['Hello world!', 'The cat sat on the mat.']) +# ['Hallo Welt!', 'Die Katze saß auf der Matte.'] +``` + +Loading custom models: +```python +from fairseq.models.transformer import TransformerModel +zh2en = TransformerModel.from_pretrained( + '/path/to/checkpoints', + checkpoint_file='checkpoint_best.pt', + data_name_or_path='data-bin/wmt17_zh_en_full', + bpe='subword_nmt', + bpe_codes='data-bin/wmt17_zh_en_full/zh.code' +) +zh2en.translate('你好 世界') +# 'Hello World' +``` + +If you are using a `transformer.wmt19` models, you will need to set the `bpe` +argument to `'fastbpe'` and (optionally) load the 4-model ensemble: +```python +en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de', + checkpoint_file='model1.pt:model2.pt:model3.pt:model4.pt', + tokenizer='moses', bpe='fastbpe') +en2de.eval() # disable dropout +``` + +## Example usage (CLI tools) + +Generation with the binarized test sets can be run in batch mode as follows, e.g. 
for WMT 2014 English-French on a GTX-1080ti: +```bash +mkdir -p data-bin +curl https://dl.fbaipublicfiles.com/fairseq/models/wmt14.v2.en-fr.fconv-py.tar.bz2 | tar xvjf - -C data-bin +curl https://dl.fbaipublicfiles.com/fairseq/data/wmt14.v2.en-fr.newstest2014.tar.bz2 | tar xvjf - -C data-bin +fairseq-generate data-bin/wmt14.en-fr.newstest2014 \ + --path data-bin/wmt14.en-fr.fconv-py/model.pt \ + --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out +# ... +# | Translated 3003 sentences (96311 tokens) in 166.0s (580.04 tokens/s) +# | Generate test with beam=5: BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) + +# Compute BLEU score +grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys +grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref +fairseq-score --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref +# BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) +``` + +## Training a new model + +### IWSLT'14 German to English (Transformer) + +The following instructions can be used to train a Transformer model on the [IWSLT'14 German to English dataset](http://workshop2014.iwslt.org/downloads/proceeding.pdf). + +First download and preprocess the data: +```bash +# Download and prepare the data +cd examples/translation/ +bash prepare-iwslt14.sh +cd ../.. 
+ +# Preprocess/binarize the data +TEXT=examples/translation/iwslt14.tokenized.de-en +fairseq-preprocess --source-lang de --target-lang en \ + --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ + --destdir data-bin/iwslt14.tokenized.de-en \ + --workers 20 +``` + +Next we'll train a Transformer translation model over this data: +```bash +CUDA_VISIBLE_DEVICES=0 fairseq-train \ + data-bin/iwslt14.tokenized.de-en \ + --arch transformer_iwslt_de_en --share-decoder-input-output-embed \ + --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \ + --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \ + --dropout 0.3 --weight-decay 0.0001 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --max-tokens 4096 \ + --eval-bleu \ + --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \ + --eval-bleu-detok moses \ + --eval-bleu-remove-bpe \ + --eval-bleu-print-samples \ + --best-checkpoint-metric bleu --maximize-best-checkpoint-metric +``` + +Finally we can evaluate our trained model: +```bash +fairseq-generate data-bin/iwslt14.tokenized.de-en \ + --path checkpoints/checkpoint_best.pt \ + --batch-size 128 --beam 5 --remove-bpe +``` + +### WMT'14 English to German (Convolutional) + +The following instructions can be used to train a Convolutional translation model on the WMT English to German dataset. +See the [Scaling NMT README](../scaling_nmt/README.md) for instructions to train a Transformer translation model on this data. + +The WMT English to German dataset can be preprocessed using the `prepare-wmt14en2de.sh` script. +By default it will produce a dataset that was modeled after [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762), but with additional news-commentary-v12 data from WMT'17. 
+ +To use only data available in WMT'14 or to replicate results obtained in the original [Convolutional Sequence to Sequence Learning (Gehring et al., 2017)](https://arxiv.org/abs/1705.03122) paper, please use the `--icml17` option. + +```bash +# Download and prepare the data +cd examples/translation/ +# WMT'17 data: +bash prepare-wmt14en2de.sh +# or to use WMT'14 data: +# bash prepare-wmt14en2de.sh --icml17 +cd ../.. + +# Binarize the dataset +TEXT=examples/translation/wmt17_en_de +fairseq-preprocess \ + --source-lang en --target-lang de \ + --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ + --destdir data-bin/wmt17_en_de --thresholdtgt 0 --thresholdsrc 0 \ + --workers 20 + +# Train the model +mkdir -p checkpoints/fconv_wmt_en_de +fairseq-train \ + data-bin/wmt17_en_de \ + --arch fconv_wmt_en_de \ + --dropout 0.2 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --optimizer nag --clip-norm 0.1 \ + --lr 0.5 --lr-scheduler fixed --force-anneal 50 \ + --max-tokens 4000 \ + --save-dir checkpoints/fconv_wmt_en_de + +# Evaluate +fairseq-generate data-bin/wmt17_en_de \ + --path checkpoints/fconv_wmt_en_de/checkpoint_best.pt \ + --beam 5 --remove-bpe +``` + +### WMT'14 English to French +```bash +# Download and prepare the data +cd examples/translation/ +bash prepare-wmt14en2fr.sh +cd ../.. 
+ +# Binarize the dataset +TEXT=examples/translation/wmt14_en_fr +fairseq-preprocess \ + --source-lang en --target-lang fr \ + --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \ + --destdir data-bin/wmt14_en_fr --thresholdtgt 0 --thresholdsrc 0 \ + --workers 60 + +# Train the model +mkdir -p checkpoints/fconv_wmt_en_fr +fairseq-train \ + data-bin/wmt14_en_fr \ + --arch fconv_wmt_en_fr \ + --dropout 0.1 \ + --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \ + --optimizer nag --clip-norm 0.1 \ + --lr 0.5 --lr-scheduler fixed --force-anneal 50 \ + --max-tokens 3000 \ + --save-dir checkpoints/fconv_wmt_en_fr + +# Evaluate +fairseq-generate \ + data-bin/fconv_wmt_en_fr \ + --path checkpoints/fconv_wmt_en_fr/checkpoint_best.pt \ + --beam 5 --remove-bpe +``` + +## Multilingual Translation + +We also support training multilingual translation models. In this example we'll +train a multilingual `{de,fr}-en` translation model using the IWSLT'17 datasets. + +Note that we use slightly different preprocessing here than for the IWSLT'14 +En-De data above. In particular we learn a joint BPE code for all three +languages and use fairseq-interactive and sacrebleu for scoring the test set. + +```bash +# First install sacrebleu and sentencepiece +pip install sacrebleu sentencepiece + +# Then download and preprocess the data +cd examples/translation/ +bash prepare-iwslt17-multilingual.sh +cd ../.. 
+ +# Binarize the de-en dataset +TEXT=examples/translation/iwslt17.de_fr.en.bpe16k +fairseq-preprocess --source-lang de --target-lang en \ + --trainpref $TEXT/train.bpe.de-en \ + --validpref $TEXT/valid0.bpe.de-en,$TEXT/valid1.bpe.de-en,$TEXT/valid2.bpe.de-en,$TEXT/valid3.bpe.de-en,$TEXT/valid4.bpe.de-en,$TEXT/valid5.bpe.de-en \ + --destdir data-bin/iwslt17.de_fr.en.bpe16k \ + --workers 10 + +# Binarize the fr-en dataset +# NOTE: it's important to reuse the en dictionary from the previous step +fairseq-preprocess --source-lang fr --target-lang en \ + --trainpref $TEXT/train.bpe.fr-en \ + --validpref $TEXT/valid0.bpe.fr-en,$TEXT/valid1.bpe.fr-en,$TEXT/valid2.bpe.fr-en,$TEXT/valid3.bpe.fr-en,$TEXT/valid4.bpe.fr-en,$TEXT/valid5.bpe.fr-en \ + --tgtdict data-bin/iwslt17.de_fr.en.bpe16k/dict.en.txt \ + --destdir data-bin/iwslt17.de_fr.en.bpe16k \ + --workers 10 + +# Train a multilingual transformer model +# NOTE: the command below assumes 1 GPU, but accumulates gradients from +# 8 fwd/bwd passes to simulate training on 8 GPUs +mkdir -p checkpoints/multilingual_transformer +CUDA_VISIBLE_DEVICES=0 fairseq-train data-bin/iwslt17.de_fr.en.bpe16k/ \ + --max-epoch 50 \ + --ddp-backend=legacy_ddp \ + --task multilingual_translation --lang-pairs de-en,fr-en \ + --arch multilingual_transformer_iwslt_de_en \ + --share-decoders --share-decoder-input-output-embed \ + --optimizer adam --adam-betas '(0.9, 0.98)' \ + --lr 0.0005 --lr-scheduler inverse_sqrt \ + --warmup-updates 4000 --warmup-init-lr '1e-07' \ + --label-smoothing 0.1 --criterion label_smoothed_cross_entropy \ + --dropout 0.3 --weight-decay 0.0001 \ + --save-dir checkpoints/multilingual_transformer \ + --max-tokens 4000 \ + --update-freq 8 + +# Generate and score the test set with sacrebleu +SRC=de +sacrebleu --test-set iwslt17 --language-pair ${SRC}-en --echo src \ + | python scripts/spm_encode.py --model examples/translation/iwslt17.de_fr.en.bpe16k/sentencepiece.bpe.model \ + > iwslt17.test.${SRC}-en.${SRC}.bpe +cat 
iwslt17.test.${SRC}-en.${SRC}.bpe \ + | fairseq-interactive data-bin/iwslt17.de_fr.en.bpe16k/ \ + --task multilingual_translation --lang-pairs de-en,fr-en \ + --source-lang ${SRC} --target-lang en \ + --path checkpoints/multilingual_transformer/checkpoint_best.pt \ + --buffer-size 2000 --batch-size 128 \ + --beam 5 --remove-bpe=sentencepiece \ + > iwslt17.test.${SRC}-en.en.sys +grep ^H iwslt17.test.${SRC}-en.en.sys | cut -f3 \ + | sacrebleu --test-set iwslt17 --language-pair ${SRC}-en +``` + +##### Argument format during inference + +During inference it is required to specify a single `--source-lang` and +`--target-lang`, which indicates the inference langauge direction. +`--lang-pairs`, `--encoder-langtok`, `--decoder-langtok` have to be set to +the same value as training. diff --git a/fairseq/examples/translation/prepare-iwslt17-multilingual.sh b/fairseq/examples/translation/prepare-iwslt17-multilingual.sh new file mode 100644 index 0000000000000000000000000000000000000000..23be87555322bc03b13e9d95951d88b1a442f97a --- /dev/null +++ b/fairseq/examples/translation/prepare-iwslt17-multilingual.sh @@ -0,0 +1,133 @@ +#!/bin/bash +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+
+SRCS=(
+    "de"
+    "fr"
+)
+TGT=en
+
+ROOT=$(dirname "$0")
+SCRIPTS=$ROOT/../../scripts
+SPM_TRAIN=$SCRIPTS/spm_train.py
+SPM_ENCODE=$SCRIPTS/spm_encode.py
+
+BPESIZE=16384
+ORIG=$ROOT/iwslt17_orig
+DATA=$ROOT/iwslt17.de_fr.en.bpe16k
+mkdir -p "$ORIG" "$DATA"
+
+TRAIN_MINLEN=1  # remove sentences with <1 BPE token
+TRAIN_MAXLEN=250  # remove sentences with >250 BPE tokens
+
+URLS=(
+    "https://wit3.fbk.eu/archive/2017-01-trnted/texts/de/en/de-en.tgz"
+    "https://wit3.fbk.eu/archive/2017-01-trnted/texts/fr/en/fr-en.tgz"
+)
+ARCHIVES=(
+    "de-en.tgz"
+    "fr-en.tgz"
+)
+VALID_SETS=(
+    "IWSLT17.TED.dev2010.de-en IWSLT17.TED.tst2010.de-en IWSLT17.TED.tst2011.de-en IWSLT17.TED.tst2012.de-en IWSLT17.TED.tst2013.de-en IWSLT17.TED.tst2014.de-en IWSLT17.TED.tst2015.de-en"
+    "IWSLT17.TED.dev2010.fr-en IWSLT17.TED.tst2010.fr-en IWSLT17.TED.tst2011.fr-en IWSLT17.TED.tst2012.fr-en IWSLT17.TED.tst2013.fr-en IWSLT17.TED.tst2014.fr-en IWSLT17.TED.tst2015.fr-en"
+)
+
+# download and extract data
+for ((i=0;i<${#URLS[@]};++i)); do
+    ARCHIVE=$ORIG/${ARCHIVES[i]}
+    if [ -f "$ARCHIVE" ]; then
+        echo "$ARCHIVE already exists, skipping download"
+    else
+        URL=${URLS[i]}
+        wget -P "$ORIG" "$URL"
+        if [ -f "$ARCHIVE" ]; then
+            echo "$URL successfully downloaded."
+        else
+            echo "$URL not successfully downloaded."
+            exit 1
+        fi
+    fi
+    FILE=${ARCHIVE: -4}
+    if [ -e "$FILE" ]; then
+        echo "$FILE already exists, skipping extraction"
+    else
+        tar -C "$ORIG" -xzvf "$ARCHIVE"
+    fi
+done
+
+echo "pre-processing train data..."
+for SRC in "${SRCS[@]}"; do
+    for LANG in "${SRC}" "${TGT}"; do
+        cat "$ORIG/${SRC}-${TGT}/train.tags.${SRC}-${TGT}.${LANG}" \
+            | grep -v '<url>' \
+            | grep -v '<talkid>' \
+            | grep -v '<keywords>' \
+            | grep -v '<speaker>' \
+            | grep -v '<reviewer>' \
+            | sed -e 's/<title>//g' \
+            | sed -e 's/<\/title>//g' \
+            | sed -e 's/<description>//g' \
+            | sed -e 's/<\/description>//g' \
+            | sed 's/^\s*//g' \
+            | sed 's/\s*$//g' \
+            > "$DATA/train.${SRC}-${TGT}.${LANG}"
+    done
+done
+
+echo "pre-processing valid data..."
+for ((i=0;i<${#SRCS[@]};++i)); do + SRC=${SRCS[i]} + VALID_SET=(${VALID_SETS[i]}) + for ((j=0;j<${#VALID_SET[@]};++j)); do + FILE=${VALID_SET[j]} + for LANG in "$SRC" "$TGT"; do + grep '<seg id' "$ORIG/${SRC}-${TGT}/${FILE}.${LANG}.xml" \ + | sed -e 's/<seg id="[0-9]*">\s*//g' \ + | sed -e 's/\s*<\/seg>\s*//g' \ + | sed -e "s/\’/\'/g" \ + > "$DATA/valid${j}.${SRC}-${TGT}.${LANG}" + done + done +done + +# learn BPE with sentencepiece +TRAIN_FILES=$(for SRC in "${SRCS[@]}"; do echo $DATA/train.${SRC}-${TGT}.${SRC}; echo $DATA/train.${SRC}-${TGT}.${TGT}; done | tr "\n" ",") +echo "learning joint BPE over ${TRAIN_FILES}..." +python "$SPM_TRAIN" \ + --input=$TRAIN_FILES \ + --model_prefix=$DATA/sentencepiece.bpe \ + --vocab_size=$BPESIZE \ + --character_coverage=1.0 \ + --model_type=bpe + +# encode train/valid +echo "encoding train with learned BPE..." +for SRC in "${SRCS[@]}"; do + python "$SPM_ENCODE" \ + --model "$DATA/sentencepiece.bpe.model" \ + --output_format=piece \ + --inputs $DATA/train.${SRC}-${TGT}.${SRC} $DATA/train.${SRC}-${TGT}.${TGT} \ + --outputs $DATA/train.bpe.${SRC}-${TGT}.${SRC} $DATA/train.bpe.${SRC}-${TGT}.${TGT} \ + --min-len $TRAIN_MINLEN --max-len $TRAIN_MAXLEN +done + +echo "encoding valid with learned BPE..." 
+for ((i=0;i<${#SRCS[@]};++i)); do + SRC=${SRCS[i]} + VALID_SET=(${VALID_SETS[i]}) + for ((j=0;j<${#VALID_SET[@]};++j)); do + python "$SPM_ENCODE" \ + --model "$DATA/sentencepiece.bpe.model" \ + --output_format=piece \ + --inputs $DATA/valid${j}.${SRC}-${TGT}.${SRC} $DATA/valid${j}.${SRC}-${TGT}.${TGT} \ + --outputs $DATA/valid${j}.bpe.${SRC}-${TGT}.${SRC} $DATA/valid${j}.bpe.${SRC}-${TGT}.${TGT} + done +done diff --git a/fairseq/examples/translation/prepare-wmt14en2de.sh b/fairseq/examples/translation/prepare-wmt14en2de.sh new file mode 100644 index 0000000000000000000000000000000000000000..6702c88b568c9e680b525593ff0c9fb0a474825d --- /dev/null +++ b/fairseq/examples/translation/prepare-wmt14en2de.sh @@ -0,0 +1,142 @@ +#!/bin/bash +# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh + +echo 'Cloning Moses github repository (for tokenization scripts)...' +git clone https://github.com/moses-smt/mosesdecoder.git + +echo 'Cloning Subword NMT repository (for BPE pre-processing)...' 
+git clone https://github.com/rsennrich/subword-nmt.git + +SCRIPTS=mosesdecoder/scripts +TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl +CLEAN=$SCRIPTS/training/clean-corpus-n.perl +NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl +REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl +BPEROOT=subword-nmt/subword_nmt +BPE_TOKENS=40000 + +URLS=( + "http://statmt.org/wmt13/training-parallel-europarl-v7.tgz" + "http://statmt.org/wmt13/training-parallel-commoncrawl.tgz" + "http://data.statmt.org/wmt17/translation-task/training-parallel-nc-v12.tgz" + "http://data.statmt.org/wmt17/translation-task/dev.tgz" + "http://statmt.org/wmt14/test-full.tgz" +) +FILES=( + "training-parallel-europarl-v7.tgz" + "training-parallel-commoncrawl.tgz" + "training-parallel-nc-v12.tgz" + "dev.tgz" + "test-full.tgz" +) +CORPORA=( + "training/europarl-v7.de-en" + "commoncrawl.de-en" + "training/news-commentary-v12.de-en" +) + +# This will make the dataset compatible to the one used in "Convolutional Sequence to Sequence Learning" +# https://arxiv.org/abs/1705.03122 +if [ "$1" == "--icml17" ]; then + URLS[2]="http://statmt.org/wmt14/training-parallel-nc-v9.tgz" + FILES[2]="training-parallel-nc-v9.tgz" + CORPORA[2]="training/news-commentary-v9.de-en" + OUTDIR=wmt14_en_de +else + OUTDIR=wmt17_en_de +fi + +if [ ! -d "$SCRIPTS" ]; then + echo "Please set SCRIPTS variable correctly to point to Moses scripts." + exit +fi + +src=en +tgt=de +lang=en-de +prep=$OUTDIR +tmp=$prep/tmp +orig=orig +dev=dev/newstest2013 + +mkdir -p $orig $tmp $prep + +cd $orig + +for ((i=0;i<${#URLS[@]};++i)); do + file=${FILES[i]} + if [ -f $file ]; then + echo "$file already exists, skipping download" + else + url=${URLS[i]} + wget "$url" + if [ -f $file ]; then + echo "$url successfully downloaded." + else + echo "$url not successfully downloaded." + exit -1 + fi + if [ ${file: -4} == ".tgz" ]; then + tar zxvf $file + elif [ ${file: -4} == ".tar" ]; then + tar xvf $file + fi + fi +done +cd .. 
+ +echo "pre-processing train data..." +for l in $src $tgt; do + rm $tmp/train.tags.$lang.tok.$l + for f in "${CORPORA[@]}"; do + cat $orig/$f.$l | \ + perl $NORM_PUNC $l | \ + perl $REM_NON_PRINT_CHAR | \ + perl $TOKENIZER -threads 8 -a -l $l >> $tmp/train.tags.$lang.tok.$l + done +done + +echo "pre-processing test data..." +for l in $src $tgt; do + if [ "$l" == "$src" ]; then + t="src" + else + t="ref" + fi + grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \ + sed -e 's/<seg id="[0-9]*">\s*//g' | \ + sed -e 's/\s*<\/seg>\s*//g' | \ + sed -e "s/\’/\'/g" | \ + perl $TOKENIZER -threads 8 -a -l $l > $tmp/test.$l + echo "" +done + +echo "splitting train and valid..." +for l in $src $tgt; do + awk '{if (NR%100 == 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/valid.$l + awk '{if (NR%100 != 0) print $0; }' $tmp/train.tags.$lang.tok.$l > $tmp/train.$l +done + +TRAIN=$tmp/train.de-en +BPE_CODE=$prep/code +rm -f $TRAIN +for l in $src $tgt; do + cat $tmp/train.$l >> $TRAIN +done + +echo "learn_bpe.py on ${TRAIN}..." +python $BPEROOT/learn_bpe.py -s $BPE_TOKENS < $TRAIN > $BPE_CODE + +for L in $src $tgt; do + for f in train.$L valid.$L test.$L; do + echo "apply_bpe.py to ${f}..." 
+ python $BPEROOT/apply_bpe.py -c $BPE_CODE < $tmp/$f > $tmp/bpe.$f + done +done + +perl $CLEAN -ratio 1.5 $tmp/bpe.train $src $tgt $prep/train 1 250 +perl $CLEAN -ratio 1.5 $tmp/bpe.valid $src $tgt $prep/valid 1 250 + +for L in $src $tgt; do + cp $tmp/bpe.test.$L $prep/test.$L +done diff --git a/fairseq/examples/unsupervised_quality_estimation/README.md b/fairseq/examples/unsupervised_quality_estimation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e86a0d13b883af0c37fdc2c1fee9b0b9dff4d18c --- /dev/null +++ b/fairseq/examples/unsupervised_quality_estimation/README.md @@ -0,0 +1,126 @@ +# Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020) + +This page includes instructions for reproducing results from the paper [Unsupervised Quality Estimation for Neural +Machine Translation (Fomicheva et al., 2020)](https://arxiv.org/abs/2005.10608) + +## Requirements: + +* mosesdecoder: https://github.com/moses-smt/mosesdecoder +* subword-nmt: https://github.com/rsennrich/subword-nmt +* flores: https://github.com/facebookresearch/flores + +## Download Models and Test Data + +Download translation models and test data from [MLQE dataset repository](https://github.com/facebookresearch/mlqe). + +## Set up: + +Given a testset consisting of source sentences and reference translations: + +* `SRC_LANG`: source language +* `TGT_LANG`: target language +* `INPUT`: input prefix, such that the file `$INPUT.$SRC_LANG` contains source sentences and `$INPUT.$TGT_LANG` +contains the reference sentences +* `OUTPUT_DIR`: output path to store results +* `MOSES_DECODER`: path to mosesdecoder installation +* `BPE_ROOT`: path to subword-nmt installation +* `BPE`: path to BPE model +* `MODEL_DIR`: directory containing the NMT model `.pt` file as well as the source and target vocabularies. 
+* `TMP`: directory for intermediate temporary files +* `GPU`: if translating with GPU, id of the GPU to use for inference +* `DROPOUT_N`: number of stochastic forward passes + +`$DROPOUT_N` is set to 30 in the experiments reported in the paper. However, we observed that increasing it beyond 10 +does not bring substantial improvements. + +## Translate the data using standard decoding + +Preprocess the input data: +``` +for LANG in $SRC_LANG $TGT_LANG; do + perl $MOSES_DECODER/scripts/tokenizer/tokenizer.perl -threads 80 -a -l $LANG < $INPUT.$LANG > $TMP/preprocessed.tok.$LANG + python $BPE_ROOT/apply_bpe.py -c ${BPE} < $TMP/preprocessed.tok.$LANG > $TMP/preprocessed.tok.bpe.$LANG +done +``` + +Binarize the data for faster translation: + +``` +fairseq-preprocess --srcdict $MODEL_DIR/dict.$SRC_LANG.txt --tgtdict $MODEL_DIR/dict.$TGT_LANG.txt +--source-lang ${SRC_LANG} --target-lang ${TGT_LANG} --testpref $TMP/preprocessed.tok.bpe --destdir $TMP/bin --workers 4 +``` + +Translate + +``` +CUDA_VISIBLE_DEVICES=$GPU fairseq-generate $TMP/bin --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 +--source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 > $TMP/fairseq.out +grep ^H $TMP/fairseq.out | cut -d- -f2- | sort -n | cut -f3- > $TMP/mt.out +``` + +Post-process + +``` +sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/mt.out | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl +-l $TGT_LANG > $OUTPUT_DIR/mt.out +``` + +## Produce uncertainty estimates + +### Scoring + +Make temporary files to store the translations repeated N times. 
+ +``` +python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/preprocessed.tok.bpe.$SRC_LANG -n $DROPOUT_N +-o $TMP/repeated.$SRC_LANG +python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/mt.out -n $DROPOUT_N -o $TMP/repeated.$TGT_LANG + +fairseq-preprocess --srcdict ${MODEL_DIR}/dict.${SRC_LANG}.txt $TGT_DIC --source-lang ${SRC_LANG} +--target-lang ${TGT_LANG} --testpref ${TMP}/repeated --destdir ${TMP}/bin-repeated +``` + +Produce model scores for the generated translations using `--retain-dropout` option to apply dropout at inference time: + +``` +CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt --beam 5 + --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout + --retain-dropout-modules '["TransformerModel","TransformerEncoder","TransformerDecoder","TransformerEncoderLayer"]' + TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out + +grep ^H $TMP/dropout.scoring.out | cut -d- -f2- | sort -n | cut -f2 > $TMP/dropout.scores + +``` + +Use `--retain-dropout-modules` to specify the modules. By default, dropout is applied in the same places +as for training. 
+ +Compute the mean of the resulting output distribution: + +``` +python $SCRIPTS/scripts/uncertainty/aggregate_scores.py -i $TMP/dropout.scores -o $OUTPUT_DIR/dropout.scores.mean +-n $DROPOUT_N +``` + +### Generation + +Produce multiple translation hypotheses for the same source using `--retain-dropout` option: + +``` +CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${LP}.pt + --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --retain-dropout + --unkpen 5 --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder +TransformerEncoderLayer TransformerDecoderLayer --seed 46 > $TMP/dropout.generation.out + +grep ^H $TMP/dropout.generation.out | cut -d- -f2- | sort -n | cut -f3- > $TMP/dropout.hypotheses_ + +sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/dropout.hypotheses_ | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl +-l $TGT_LANG > $TMP/dropout.hypotheses +``` + +Compute similarity between multiple hypotheses corresponding to the same source sentence using Meteor +evaluation metric: +``` +python meteor.py -i $TMP/dropout.hypotheses -m <path_to_meteor_installation> -n $DROPOUT_N -o +$OUTPUT_DIR/dropout.gen.sim.meteor +``` diff --git a/fairseq/examples/wav2vec/README.md b/fairseq/examples/wav2vec/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e97973307505e902307aa24b3885800378193b5d --- /dev/null +++ b/fairseq/examples/wav2vec/README.md @@ -0,0 +1,426 @@ +# wav2vec 2.0 + +wav2vec 2.0 learns speech representations on unlabeled data as described in [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations (Baevski et al., 2020)](https://arxiv.org/abs/2006.11477). + +We learned speech representations in multiple languages as well in [Unsupervised Cross-lingual Representation Learning for Speech Recognition (Conneau et al., 2020)](https://arxiv.org/abs/2006.13979). 
+ +We also combined wav2vec 2.0 with self-training in [Self-training and Pre-training are Complementary for Speech Recognition (Xu et al., 2020)](https://arxiv.org/abs/2010.11430). + +We combined speech data from multiple domains in [Robust wav2vec 2.0: Analyzing Domain Shift in Self-Supervised Pre-Training (Hsu, et al., 2021)](https://arxiv.org/abs/2104.01027). + +We finetuned XLSR-53 on multiple languages to transcribe unseen languages in [Simple and Effective Zero-shot Cross-lingual Phoneme Recognition (Xu et al., 2021)](https://arxiv.org/abs/2109.11680). + +## Pre-trained models + +Model | Finetuning split | Dataset | Model +|---|---|---|--- +Wav2Vec 2.0 Base | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small.pt) +Wav2Vec 2.0 Base | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_10m.pt) +Wav2Vec 2.0 Base | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_100h.pt) +Wav2Vec 2.0 Base | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_small_960h.pt) +Wav2Vec 2.0 Large | No finetuning | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/libri960_big.pt) +Wav2Vec 2.0 Large | 10 minutes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_10m.pt) +Wav2Vec 2.0 Large | 100 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_100h.pt) +Wav2Vec 2.0 Large | 960 hours | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_big_960h.pt) +Wav2Vec 2.0 Large (LV-60)* | No finetuning | 
[Libri-Light](https://github.com/facebookresearch/libri-light) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_new.pt) +Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_no_FT) +Wav2Vec 2.0 Large conformer - rope (LV-60)* | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_rope_PT_no_FT) +Wav2Vec 2.0 Large (LV-60)* | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_new.pt) +Wav2Vec 2.0 Large (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_new.pt) +Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_100h_FT.pt) +Wav2Vec 2.0 Large conformer - rope (LV-60)* | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_rope_PT_100h_FT.pt) +Wav2Vec 2.0 Large (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec2_vox_960h_new.pt) +Wav2Vec 2.0 Large conformer - rel_pos (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_relpos_PT_960h_FT.pt) +Wav2Vec 2.0 Large 
conformer - rope (LV-60)* | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) | [download](s3://dl.fbaipublicfiles.com/fairseq/conformer/wav2vec2/librilight/LL_rope_PT_960h_FT.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 10 minutes | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_10m_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 100 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_100h_pl.pt) +Wav2Vec 2.0 Large (LV-60) + Self Training * | 960 hours | [Libri-Light](https://github.com/facebookresearch/libri-light) + [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt) +Wav2Vec 2.0 Large (LV-60 + CV + SWBD + FSH) ** | No finetuning | [Libri-Light](https://github.com/facebookresearch/libri-light) + [CommonVoice](https://commonvoice.mozilla.org/en/languages) + [Switchboard](https://catalog.ldc.upenn.edu/LDC97S62) + [Fisher](https://catalog.ldc.upenn.edu/LDC2004T19) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv.pt) +Wav2Vec 2.0 Large (LV-60 + CV + SWBD + FSH) ** | 960 hours Librispeech | [Libri-Light](https://github.com/facebookresearch/libri-light) + [CommonVoice](https://commonvoice.mozilla.org/en/languages) + [Switchboard](https://catalog.ldc.upenn.edu/LDC97S62) + [Fisher](https://catalog.ldc.upenn.edu/LDC2004T19) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv_ftls960_updated.pt) +Wav2Vec 2.0 Large (LV-60 + CV + SWBD + FSH) ** | 300 hours Switchboard | [Libri-Light](https://github.com/facebookresearch/libri-light) + [CommonVoice](https://commonvoice.mozilla.org/en/languages) + 
[Switchboard](https://catalog.ldc.upenn.edu/LDC97S62) + [Fisher](https://catalog.ldc.upenn.edu/LDC2004T19) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/w2v_large_lv_fsh_swbd_cv_ftsb300_updated.pt) + +\* updated (Oct. 24, 2020)\ +** updated (Nov. 13, 2021) + +We also release multilingual pre-trained wav2vec 2.0 (XLSR) models: + +Model | Architecture | Hours | Languages | Datasets | Model +|---|---|---|---|---|--- +XLSR-53 | Large | 56k | 53 | MLS, CommonVoice, BABEL | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/xlsr_53_56k.pt) + +The XLSR model uses the following datasets for multilingual pretraining: + +* **[MLS: Multilingual LibriSpeech](https://indico2.conference4me.psnc.pl/event/35/contributions/3585/attachments/1060/1101/Wed-2-6-10.pdf)** (8 languages, 50.7k hours): *Dutch, English, French, German, Italian, Polish, Portuguese, Spanish* + +* **[CommonVoice](https://commonvoice.mozilla.org/en/languages)** (36 languages, 3.6k hours): *Arabic, Basque, Breton, Chinese (CN), Chinese (HK), Chinese (TW), Chuvash, Dhivehi, Dutch, English, Esperanto, Estonian, French, German, Hakh-Chin, Indonesian, Interlingua, Irish, Italian, Japanese, Kabyle, Kinyarwanda, Kyrgyz, Latvian, Mongolian, Persian, Portuguese, Russian, Sakha, Slovenian, Spanish, Swedish, Tamil, Tatar, Turkish, Welsh* (see also [finetuning splits]([https://dl.fbaipublicfiles.com/cpc_audio/common_voices_splits.tar.gz]) from [this paper](https://arxiv.org/abs/2002.02848)). + +* **[Babel](https://catalog.ldc.upenn.edu/byyear)** (17 languages, 1.7k hours): *Assamese, Bengali, Cantonese, Cebuano, Georgian, Haitian, Kazakh, Kurmanji, Lao, Pashto, Swahili, Tagalog, Tamil, Tok, Turkish, Vietnamese, Zulu* + +We also finetuned several models on languages from [CommonVoice](https://commonvoice.mozilla.org/en/languages) (version 6.1) and [Babel](https://catalog.ldc.upenn.edu/byyear). Please refer to [our paper](https://arxiv.org/abs/2109.11680) for details about which languages are used. 
Pretrained Model | Finetune Dataset | # Languages
+ +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) + +### Prepare training data manifest + +First, install the `soundfile` library: + +```shell script +pip install soundfile +``` + +Next, run: + +```shell script +python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext $ext --valid-percent $valid +``` + +$ext should be set to flac, wav, or whatever format your dataset happens to use that soundfile can read. + +$valid should be set to some reasonable percentage (like 0.01) of training data to use for validation. +To use a pre-defined validation set (like dev-other from librispeech), set to it 0 and then overwrite valid.tsv with a +separately pre-processed manifest file. + +### Train a wav2vec 2.0 base model + +This configuration was used for the base model trained on the Librispeech dataset in the wav2vec 2.0 paper + +Note that the input is expected to be single channel, sampled at 16 kHz + +```shell script +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_base_librispeech +``` + +Note: you can simulate 64 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 64/k + +### Train a wav2vec 2.0 large model + +This configuration was used for the large model trained on the Libri-light dataset in the wav2vec 2.0 paper + +```shell script +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox +``` + +Note: you can simulate 128 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` 
`+optimization.update_freq='[x]'` where x = 128/k + +### Train a wav2vec 2.0 model with conformer backbone + +To replace the transformer layers in the encoder with the conformer layers, set `--layer-type conformer --attn-type espnet --pos-enc-type ${POS_ENC_TYPE}`. `POS_ENC_TYPE` refers to positional encoding to be used in the conformer encoder. +Set it to `abs`, `rope` or `rel_pos` to use the absolute positional encoding, rotary positional encoding or relative positional encoding in the conformer layer respectively. + +To train a base model with conformer: + +```shell script +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_conformer_base_librispeech \ + --attn-type espnet --pos-enc-type ${POS_ENC_TYPE} +``` + +To train a large model with conformer: + +```shell script +$ fairseq-hydra-train \ + task.data=/path/to/data \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_conformer_large_librivox + --attn-type espnet --pos-enc-type ${POS_ENC_TYPE} + +``` + +### Fine-tune a pre-trained model with CTC + +Fine-tuning a model requires parallel audio and labels file, as well as a vocabulary file in fairseq format. +A letter vocabulary can be downloaded [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). 
+An example [script](libri_labels.py) that generates labels for the Librispeech dataset from the tsv file produced by wav2vec_manifest.py can be used as follows: + +```shell script +split=train +$ python libri_labels.py /path/to/tsv --output-dir /output/dir --output-name $split +``` + +Fine-tuning on 100h of Librispeech with letter targets: + +```shell script +$ fairseq-hydra-train \ + distributed_training.distributed_port=$PORT \ + task.data=/path/to/data \ + model.w2v_path=/path/to/model.pt \ + --config-dir /path/to/fairseq-py/examples/wav2vec/config/finetuning \ + --config-name base_100h +``` + +There are other config files in the config/finetuning directory that can be used to fine-tune on other splits. +You can specify the right config via the `--config-name` parameter. + +Note: you can simulate 24 GPUs by using k GPUs and adding command line parameters (before `--config-dir`) +`distributed_training.distributed_world_size=k` `+optimization.update_freq='[x]'` where x = 24/k + +Decoding with a language model during training requires flashlight [python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter). +If you want to use a language model, add `+criterion.wer_args='[/path/to/kenlm, /path/to/lexicon, 2, -1]'` to the command line. + +### Evaluating a CTC model + +Evaluating a CTC model with a language model requires [flashlight python bindings](https://github.com/facebookresearch/flashlight/tree/master/bindings/python) (previously called [wav2letter](https://github.com/facebookresearch/wav2letter) to be installed. + +Fairseq transformer language model used in the wav2vec 2.0 paper can be obtained from the [wav2letter model repository](https://github.com/facebookresearch/wav2letter/tree/master/recipes/sota/2019). +Be sure to upper-case the language model vocab after downloading it. 
+ +Letter dictionary for pre-trained models can be found [here](https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt). + +Next, run the evaluation command: + +```shell script +$subset=dev_other +python examples/speech_recognition/infer.py /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw --task audio_finetuning \ +--nbest 1 --path /path/to/model --gen-subset $subset --results-path /path/to/save/results/for/sclite --w2l-decoder kenlm \ +--lm-model /path/to/kenlm.bin --lm-weight 2 --word-score -1 --sil-weight 0 --criterion ctc --labels ltr --max-tokens 4000000 \ +--post-process letter +``` + +To get raw numbers, use --w2l-decoder viterbi and omit the lexicon. To use the transformer language model, use --w2l-decoder fairseqlm. + +## Use wav2vec 2.0 with 🤗Transformers + +Wav2Vec2 is also available in the [🤗Transformers library](https://github.com/huggingface/transformers) since version 4.4. + +Pretrained Models can be found on the [hub](https://huggingface.co/models?filter=wav2vec2) +and documentation can be found [here](https://huggingface.co/transformers/master/model_doc/wav2vec2.html). 
+ +Usage example: + +```python +# !pip install transformers +# !pip install datasets +import soundfile as sf +import torch +from datasets import load_dataset +from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor + +# load pretrained model +processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h") +model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h") + + +librispeech_samples_ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation") + +# load audio +audio_input, sample_rate = sf.read(librispeech_samples_ds[0]["file"]) + +# pad input values and return pt tensor +input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values + +# INFERENCE + +# retrieve logits & take argmax +logits = model(input_values).logits +predicted_ids = torch.argmax(logits, dim=-1) + +# transcribe +transcription = processor.decode(predicted_ids[0]) + +# FINE-TUNE + +target_transcription = "A MAN SAID TO THE UNIVERSE I EXIST" + +# encode labels +with processor.as_target_processor(): + labels = processor(target_transcription, return_tensors="pt").input_ids + +# compute loss by passing labels +loss = model(input_values, labels=labels).loss +loss.backward() +``` + +# wav2vec + +Example to train a wav2vec model as described in [wav2vec: Unsupervised Pre-training for Speech Recognition (Schneider et al., 2019)](https://arxiv.org/abs/1904.05862). 
+ +## Pre-trained models + +Description | Dataset | Model +---|---|--- +Wav2Vec large | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_large.pt) + +#### Example usage + +```python +import torch +import fairseq + +cp_path = '/path/to/wav2vec.pt' +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) +model = model[0] +model.eval() + +wav_input_16khz = torch.randn(1,10000) +z = model.feature_extractor(wav_input_16khz) +c = model.feature_aggregator(z) +``` + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate files 10 to 30 seconds in length) + +### Prepare training data manifest + +``` +python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav +``` + +### Train a wav2vec model + +``` +$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test +``` + +### Run wav2vec2 pre-training on Google Cloud TPUs + +Wav2Vec2 is now supported on TPUs! It's currently pre-training only. 
+ +#### Using hydra on a v3-8 + +``` +$ OMP_NUM_THREADS=1 fairseq-hydra-train \ + task.data=/manifest/path \ + --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox_tpu.yaml +``` + +#### Using command line arguments on a v3-8 + +Note: Commandline arguments way of execution has a [known-problem](https://github.com/pytorch/fairseq/issues/3741) currently. + +``` +$ OMP_NUM_THREADS=1 python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \ +--tpu --distributed-world-size 8 --num-batch-buckets 3 --enable-padding \ +--encoder-layerdrop 0 --mask-channel-prob 0.1 +``` + +#### Using hydra on a pod slice (v3-N with N > 8) + +``` +$ OMP_NUM_THREADS=1 fairseq-hydra-train \ + task.data=/manifest/path \ + --config-dir /PATH/TO/FAIRSEQ/examples/wav2vec/config/pretraining \ + --config-name wav2vec2_large_librivox_tpu-pod.yaml # edit distributed-world-size accordingly +``` + +#### Using command line arguments on a pod slice (v3-N with N > 8) + +Note: Commandline arguments way of execution has a [known-problem](https://github.com/pytorch/fairseq/issues/3741) currently. 
+ +``` +$ python -m torch_xla.distributed.xla_dist \ + --tpu ${TPUNAME} --conda-env=torch-xla-${TORCH_XLA_VERSION} --env OMP_NUM_THREADS=1 \ + -- \ +python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \ +--arch wav2vec2 --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \ +--max-sample-size 150000 --max-tokens 1500000 --skip-invalid-size-inputs-valid-test \ +--tpu --distributed-world-size ${WORLD_SIZE} --num-batch-buckets 3 --enable-padding \ +--encoder-layerdrop 0 --mask-channel-prob 0.1 +``` + +### Extract embeddings from the downstream task data + +``` +$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/wav2vec_featurize.py --input /path/to/task/waves --output /path/to/output \ +--model /model/path/checkpoint_best.pt --split train valid test +``` + +# vq-wav2vec + +Example to train a vq-wav2vec model as described in [vq-wav2vec: Self-Supervised Learning of Discrete Speech Representations (Baevski et al., 2019)](https://arxiv.org/abs/1910.05453). + +These models are also used in [Effectiveness of self-supervised pre-training for speech recognition (Baevski et al., 2019)](https://arxiv.org/abs/1911.03912). 
+ +## Pre-trained models + +Description | Dataset | Model +---|---|--- +vq-wav2vec Gumbel | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec.pt) +vq-wav2vec K-means | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/vq-wav2vec_kmeans.pt) +Roberta on K-means codes | [Librispeech](http://www.openslr.org/12) | [download](https://dl.fbaipublicfiles.com/fairseq/wav2vec/bert_kmeans.tar) + +#### Example usage + +```python +import torch +import fairseq + +cp = torch.load('/path/to/vq-wav2vec.pt') +model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp]) +model = model[0] +model.eval() + +wav_input_16khz = torch.randn(1,10000) +z = model.feature_extractor(wav_input_16khz) +_, idxs = model.vector_quantizer.forward_idx(z) +print(idxs.shape) # output: torch.Size([1, 60, 2]), 60 timesteps with 2 indexes corresponding to 2 groups in the model +``` + +## Training a new model with the CLI tools + +Given a directory containing wav files to be used for pretraining (we recommend splitting each file into separate file 10 to 30 seconds in length) + +### Prepare training data manifest + +``` +python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/path --ext wav +``` + +### Train a gumbel vq-wav2vec model + +``` +$ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \ +--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \ +--optimizer adam --lr 1e-05 --lr-scheduler cosine \ +--conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \ +--conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \ +--activation gelu --offset auto 
--skip-connections-agg --residual-scale 0.5 \ +--log-keys ["prob_perplexity","code_perplexity","temp"] --vq-type gumbel --vq-groups 2 --vq-depth 2 \ +--combine-groups --vq-vars 320 --vq-temp (2,0.5,0.999995) --prediction-steps 12 --warmup-updates 1000 \ +--warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 --max-sample-size 150000 \ +--max-tokens 300000 --cross-sample-negatives 0 --update-freq 1 --seed 2 --skip-invalid-size-inputs-valid-test +``` + +for k-means training, set vq-type with "kmeans" and add --loss-weights [1] argument. Pre-trained models were trained on 16 GPUs. + +### Tokenize audio data (e.g. for BERT training) + +``` +$ PYTHONPATH=/path/to/fairseq python examples/wav2vec/vq-wav2vec_featurize.py --data-dir /manifest/path --output-dir /path/to/output \ +--checkpoint /model/path/checkpoint_best.pt --split train valid test --extension tsv +``` diff --git a/fairseq/examples/wav2vec/config/finetuning/base_100h.yaml b/fairseq/examples/wav2vec/config/finetuning/base_100h.yaml new file mode 100644 index 0000000000000000000000000000000000000000..153b5df1708f187933a68ff55009652a107c69eb --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/base_100h.yaml @@ -0,0 +1,58 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: ??? 
+ normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [0.00003] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 diff --git a/fairseq/examples/wav2vec/config/finetuning/base_10h.yaml b/fairseq/examples/wav2vec/config/finetuning/base_10h.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5044518025aadf73b3f5f204a74969cf3f982f03 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/base_10h.yaml @@ -0,0 +1,63 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 50 + save_interval_updates: 10000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: ??? 
+ normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 50 + valid_subset: dev_other + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 2 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 20000 + lr: [0.00005] + sentence_avg: true + update_freq: [4] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.05 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 diff --git a/fairseq/examples/wav2vec/config/finetuning/base_960h.yaml b/fairseq/examples/wav2vec/config/finetuning/base_960h.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3eadc36b37b5caa7c0f0838e4e9ff2d401110429 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/base_960h.yaml @@ -0,0 +1,57 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: ??? + normalize: false + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 3200000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 320000 + lr: [0.0001] + sentence_avg: true + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.1 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 0 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_100h.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_100h.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b8f81e5e1895e448c6126f9ccd6d9c4a5fb36b18 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_100h.yaml @@ -0,0 +1,58 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + valid_subset: dev_other + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 80000 + lr: [0.00003] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.5 + mask_channel_prob: 0.5 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_100h_2.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_100h_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9bf588f58789f56cd97e9f72749480b24661c429 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_100h_2.yaml @@ -0,0 +1,106 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: 0 + wer_sil_weight: -2 + +optimization: + max_update: 100000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.4 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a0d517ebba5a675697d101434db02b9a2721a89 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_100h_2_aws.yaml @@ -0,0 +1,82 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /data/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx-wav2vec/abaevski/data/libri/100h/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + 
zero_infinity: true + post_process: letter + wer_kenlm_model: /fsx-wav2vec/abaevski/data/libri/4-gram.bin + wer_lexicon: /fsx-wav2vec/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: 0 + wer_sil_weight: -2 + +optimization: + max_update: 100000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 82000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.4 + mask_length: 7 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_100h_3.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_100h_3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..46778666f6d92705127f59f0c7b5e45560477da6 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_100h_3.yaml @@ -0,0 +1,101 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1000000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 1 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: 
legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 100000 + lr: [1e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 8000 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.4 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_10h.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_10h.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8f1ca71ee2bd195934f85010e141772479769f0b --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_10h.yaml @@ -0,0 +1,63 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + 
+checkpoint: + save_interval: 50 + save_interval_updates: 10000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: ??? + normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 50 + valid_subset: dev_other + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 20000 + lr: [0.0001] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.75 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_10h_2.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_10h_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..05ee76f147edf1f377fa6879c6d2bbe7bfd12ad9 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_10h_2.yaml @@ -0,0 +1,102 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 10 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + keep_interval_updates: 1 + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 10 + valid_subset: dev_other + required_batch_size_multiple: 1 + 
+distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 60000 + lr: [2e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 8000 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.5 + mask_length: 5 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58ad2acf7147e01def5dd0f5191f281f34cc4060 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_10h_aws_v100.yaml @@ -0,0 +1,102 @@ +# @package _group_ + 
+common: + fp16: true + log_format: json + log_interval: 200 +# tensorboard_logdir: tb + +checkpoint: + save_interval: 10 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /fsx/abaevski/data/libri/10h/wav2vec/raw + labels: ltr + cache_in_scratch: true + + +dataset: + num_workers: 10 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 100 + validate_interval: 10 + valid_subset: dev_other + required_batch_size_multiple: 1 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_lexicon: /fsx/abaevski/data/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 60000 + lr: [2e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [1] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: null + warmup_steps: 8000 + hold_steps: 0 + decay_steps: 72000 + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.6 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.1 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 100 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /fsx/${env:USER}/w2v_ft/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 0 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: learnfair + max_num_timeout: 30 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_10m.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_10m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..07e327fe7477c88545d7744a5928d2fe703dd6de --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_10m.yaml @@ -0,0 +1,63 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + +checkpoint: + save_interval: 1000 + save_interval_updates: 50 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: ??? 
+ normalize: true + labels: ltr + +dataset: + num_workers: 6 + max_tokens: 1280000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 1000 + valid_subset: dev_other + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 4 + +criterion: + _name: ctc + zero_infinity: true + +optimization: + max_update: 13000 + lr: [0.0001] + sentence_avg: true + update_freq: [5] + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.65 + mask_channel_prob: 0.25 + mask_channel_length: 64 + layerdrop: 0.1 + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_1h_3.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_1h_3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..842c89717e841130934282434359299d755182c9 --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_1h_3.yaml @@ -0,0 +1,104 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + save_interval: 100 + save_interval_updates: 500 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 640000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 100 + valid_subset: dev_other + required_batch_size_multiple: 8 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: 
/checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 6 + wer_word_score: -0.1 + wer_sil_weight: -4.7 + +optimization: + max_update: 13000 + lr: [6e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [5] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: cosine + warmup_updates: 4000 + +model: + _name: wav2vec_ctc + w2v_path: ??? + apply_mask: true + mask_prob: 0.3 + mask_length: 3 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/fairseq/examples/wav2vec/config/finetuning/vox_1h_4.yaml b/fairseq/examples/wav2vec/config/finetuning/vox_1h_4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..698ed8c4dafcf387b620ae58f6feeeb7384190ef --- /dev/null +++ b/fairseq/examples/wav2vec/config/finetuning/vox_1h_4.yaml @@ -0,0 +1,104 @@ +# @package _group_ + +common: + fp16: true + log_format: json + log_interval: 200 + user_dir: /private/home/abaevski/fairseq-py/examples/data2vec +# tensorboard_logdir: tb + +checkpoint: + 
save_interval: 100 + save_interval_updates: 1000 + keep_interval_updates: 1 + no_epoch_checkpoints: true + best_checkpoint_metric: wer + +task: + _name: audio_finetuning + data: /checkpoint/abaevski/data/speech/libri/1h/wav2vec/raw + labels: ltr + normalize: true + +dataset: + num_workers: 6 + max_tokens: 640000 + skip_invalid_size_inputs_valid_test: true + validate_after_updates: 10000 + validate_interval: 100 + valid_subset: dev_other + required_batch_size_multiple: 8 + +distributed_training: + ddp_backend: legacy_ddp + distributed_world_size: 8 + +criterion: + _name: ctc + zero_infinity: true + post_process: letter + wer_kenlm_model: /checkpoint/abaevski/data/speech/libri/4-gram.bin + wer_lexicon: /checkpoint/abaevski/data/speech/libri/10h/wav2vec/raw/lexicon_ltr2.lst + wer_lm_weight: 2.0 + wer_word_score: -1.0 + +optimization: + max_update: 13000 + lr: [6e-5] +# lr: [1e-5] # base 10h wer + sentence_avg: true + update_freq: [5] # base 10h we -> 2/4 + +optimizer: + _name: adam + adam_betas: (0.9,0.98) + adam_eps: 1e-08 + +lr_scheduler: + _name: tri_stage + phase_ratio: [0.1, 0.4, 0.5] + final_lr_scale: 0.05 + +model: + _name: wav2vec_ctc + w2v_path: ??? 
+ apply_mask: true + mask_prob: 0.65 + mask_length: 10 +# mask_prob: 0.65 # base 10h wer + mask_channel_prob: 0.25 +# mask_channel_prob: 0.6 # base 10h wer + mask_channel_length: 64 + layerdrop: 0.1 +# layerdrop: 0.05 # base 10h wer + activation_dropout: 0.1 + feature_grad_mult: 0.0 + freeze_finetune_updates: 10000 + dropout: 0 + final_dropout: 0 + attention_dropout: 0 + +hydra: + job: + config: + override_dirname: + kv_sep: ':' + item_sep: '__' + exclude_keys: + - run_config + - distributed_training.distributed_port + sweep: + dir: /checkpoint/${env:USER}/${env:PREFIX}/${hydra.job.config_name}/${hydra.job.override_dirname} + subdir: ${hydra.job.num} + launcher: + submitit_folder: ${hydra.sweep.dir} + timeout_min: 3000 + cpus_per_task: 10 + gpus_per_node: 4 + tasks_per_node: 4 + mem_gb: 250 + nodes: 1 + name: ${env:PREFIX}_${hydra.job.config_name} + partition: devlab,learnlab,learnfair,scavenge + constraint: volta32gb + max_num_timeout: 30 diff --git a/fairseq/examples/wav2vec/vq-wav2vec_featurize.py b/fairseq/examples/wav2vec/vq-wav2vec_featurize.py new file mode 100644 index 0000000000000000000000000000000000000000..627072ee174c22831209e00984b945eb9dc2c279 --- /dev/null +++ b/fairseq/examples/wav2vec/vq-wav2vec_featurize.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset +""" + +import argparse +import glob +import os +import os.path as osp +import pprint + +import soundfile as sf +import torch +import fairseq +from torch import nn +from torch.utils.data import DataLoader + + +try: + import tqdm +except: + print("Install tqdm to use --log-format=tqdm") + + +class FilesDataset: + def __init__(self, files, labels): + self.files = files + if labels and osp.exists(labels): + with open(labels, "r") as lbl_f: + self.labels = [line.rstrip() for line in lbl_f] + else: + self.labels = labels + + def __len__(self): + return len(self.files) + + def __getitem__(self, index): + fname = self.files[index] + + wav, sr = sf.read(fname) + assert sr == 16000 + + wav = torch.from_numpy(wav).float() + lbls = None + if self.labels: + if isinstance(self.labels, str): + lbl_file = osp.splitext(fname)[0] + "." + self.labels + with open(lbl_file, "r") as lblf: + lbls = lblf.readline() + assert lbls is not None + else: + lbls = self.labels[index] + return wav, lbls + + def collate(self, batch): + return batch + + +class ArgTypes: + @staticmethod + def existing_path(arg): + arg = str(arg) + assert osp.exists(arg), f"File {arg} does not exist" + return arg + + @staticmethod + def mkdir(arg): + arg = str(arg) + os.makedirs(arg, exist_ok=True) + return arg + + +class DatasetWriter: + def __init__(self): + + self.args = self.load_config() + pprint.pprint(self.args.__dict__) + + self.model = self.load_model() + + def __getattr__(self, attr): + return getattr(self.args, attr) + + def read_manifest(self, fname): + + with open(fname, "r") as fp: + lines = fp.read().split("\n") + root = lines.pop(0).strip() + fnames = [ + osp.join(root, line.split("\t")[0]) for line in lines if len(line) > 0 + ] + + return fnames + + def process_splits(self): + + if self.args.shard is not None or self.args.num_shards is not None: + assert self.args.shard is not None and 
self.args.num_shards is not None + + for split in self.splits: + print(split) + + if self.extension == "tsv": + datadir = osp.join(self.data_dir, f"{split}.{self.extension}") + print("Reading manifest file: ", datadir) + files = self.read_manifest(datadir) + else: + datadir = osp.join(self.data_dir, split, f"**/*.{self.extension}") + files = glob.glob(datadir, recursive=True) + + assert len(files) > 0 + + if self.args.shard is not None: + files = files[self.args.shard :: self.args.num_shards] + + lbls = [] + with open(self.data_file(split), "w") as srcf: + for line, lbl in self.iterate(files): + print(line, file=srcf) + if self.args.labels: + lbls.append(lbl + "\n") + + if self.args.labels: + assert all(a is not None for a in lbls) + with open(self.lbl_file(split), "w") as lblf: + lblf.writelines(lbls) + + def iterate(self, files): + + data = self.load_data(files) + for samples in tqdm.tqdm(data, total=len(files) // 32): + + for wav, lbl in samples: + x = wav.unsqueeze(0).float().cuda() + + div = 1 + while x.size(-1) // div > self.args.max_size: + div += 1 + + xs = x.chunk(div, dim=-1) + + result = [] + for x in xs: + torch.cuda.empty_cache() + x = self.model.feature_extractor(x) + if self.quantize_location == "encoder": + with torch.no_grad(): + _, idx = self.model.vector_quantizer.forward_idx(x) + idx = idx.squeeze(0).cpu() + else: + with torch.no_grad(): + z = self.model.feature_aggregator(x) + _, idx = self.model.vector_quantizer.forward_idx(z) + idx = idx.squeeze(0).cpu() + result.append(idx) + + idx = torch.cat(result, dim=0) + yield " ".join("-".join(map(str, a.tolist())) for a in idx), lbl + + def lbl_file(self, name): + shard_part = "" if self.args.shard is None else f".{self.args.shard}" + return osp.join(self.output_dir, f"{name}.lbl{shard_part}") + + def data_file(self, name): + shard_part = "" if self.args.shard is None else f".{self.args.shard}" + return osp.join(self.output_dir, f"{name}.src{shard_part}") + + def var_file(self): + return 
osp.join(self.output_dir, f"vars.pt") + + def load_config(self): + + parser = argparse.ArgumentParser("Vector Quantized wav2vec features") + + # Model Arguments + parser.add_argument("--checkpoint", type=ArgTypes.existing_path, required=True) + parser.add_argument("--data-parallel", action="store_true") + + # Output Arguments + parser.add_argument("--output-dir", type=ArgTypes.mkdir, required=True) + + # Data Arguments + parser.add_argument("--data-dir", type=ArgTypes.existing_path, required=True) + parser.add_argument("--splits", type=str, nargs="+", required=True) + parser.add_argument("--extension", type=str, required=True) + parser.add_argument("--labels", type=str, required=False) + + parser.add_argument("--shard", type=int, default=None) + parser.add_argument("--num-shards", type=int, default=None) + parser.add_argument("--max-size", type=int, default=1300000) + + # Logger Arguments + parser.add_argument( + "--log-format", type=str, choices=["none", "simple", "tqdm"] + ) + + return parser.parse_args() + + def load_data(self, fnames): + + dataset = FilesDataset(fnames, self.args.labels) + loader = DataLoader( + dataset, batch_size=32, collate_fn=dataset.collate, num_workers=8 + ) + return loader + + def load_model(self): + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([self.checkpoint]) + model = model[0] + + self.quantize_location = getattr(cfg.model, "vq", "encoder") + + model.eval().float() + model.cuda() + + if self.data_parallel: + model = nn.DataParallel(model) + + return model + + def __call__(self): + + self.process_splits() + + if hasattr(self.model.feature_extractor, "vars") and ( + self.args.shard is None or self.args.shard == 0 + ): + vars = ( + self.model.feature_extractor.vars.view( + self.model.feature_extractor.banks, + self.model.feature_extractor.num_vars, + -1, + ) + .cpu() + .detach() + ) + print("writing learned latent variable embeddings: ", vars.shape) + torch.save(vars, self.var_file()) + + +if __name__ == 
"__main__": + write_data = DatasetWriter() + + write_data() + print("Done.") diff --git a/fairseq/examples/wav2vec/wav2vec_featurize.py b/fairseq/examples/wav2vec/wav2vec_featurize.py new file mode 100644 index 0000000000000000000000000000000000000000..588268b7080cbd3400ac144604b2d75cef2876dd --- /dev/null +++ b/fairseq/examples/wav2vec/wav2vec_featurize.py @@ -0,0 +1,249 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Helper script to pre-compute embeddings for a flashlight (previously called wav2letter++) dataset +""" + +import argparse +import glob +import os +from shutil import copy + +import h5py +import numpy as np +import soundfile as sf +import torch +import tqdm +import fairseq +from torch import nn + + +def read_audio(fname): + """ Load an audio file and return PCM along with the sample rate """ + + wav, sr = sf.read(fname) + assert sr == 16e3 + + return wav, 16e3 + + +class PretrainedWav2VecModel(nn.Module): + def __init__(self, fname): + super().__init__() + + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([fname]) + model = model[0] + model.eval() + + self.model = model + + def forward(self, x): + with torch.no_grad(): + z = self.model.feature_extractor(x) + if isinstance(z, tuple): + z = z[0] + c = self.model.feature_aggregator(z) + return z, c + + +class EmbeddingWriterConfig(argparse.ArgumentParser): + def __init__(self): + super().__init__("Pre-compute embeddings for flashlight datasets") + + kwargs = {"action": "store", "type": str, "required": True} + + self.add_argument("--input", "-i", help="Input Directory", **kwargs) + self.add_argument("--output", "-o", help="Output Directory", **kwargs) + self.add_argument("--model", help="Path to model checkpoint", **kwargs) + self.add_argument("--split", help="Dataset Splits", nargs="+", **kwargs) + 
self.add_argument( + "--ext", default="wav", required=False, help="Audio file extension" + ) + + self.add_argument( + "--no-copy-labels", + action="store_true", + help="Do not copy label files. Useful for large datasets, use --targetdir in flashlight then.", + ) + self.add_argument( + "--use-feat", + action="store_true", + help="Use the feature vector ('z') instead of context vector ('c') for features", + ) + self.add_argument("--gpu", help="GPU to use", default=0, type=int) + + +class Prediction: + """ Lightweight wrapper around a fairspeech embedding model """ + + def __init__(self, fname, gpu=0): + self.gpu = gpu + self.model = PretrainedWav2VecModel(fname).cuda(gpu) + + def __call__(self, x): + x = torch.from_numpy(x).float().cuda(self.gpu) + with torch.no_grad(): + z, c = self.model(x.unsqueeze(0)) + + return z.squeeze(0).cpu().numpy(), c.squeeze(0).cpu().numpy() + + +class H5Writer: + """ Write features as hdf5 file in flashlight compatible format """ + + def __init__(self, fname): + self.fname = fname + os.makedirs(os.path.dirname(self.fname), exist_ok=True) + + def write(self, data): + channel, T = data.shape + + with h5py.File(self.fname, "w") as out_ds: + data = data.T.flatten() + out_ds["features"] = data + out_ds["info"] = np.array([16e3 // 160, T, channel]) + + +class EmbeddingDatasetWriter(object): + """Given a model and a flashlight dataset, pre-compute and store embeddings + + Args: + input_root, str : + Path to the flashlight dataset + output_root, str : + Desired output directory. 
Will be created if non-existent + split, str : + Dataset split + """ + + def __init__( + self, + input_root, + output_root, + split, + model_fname, + extension="wav", + gpu=0, + verbose=False, + use_feat=False, + ): + + assert os.path.exists(model_fname) + + self.model_fname = model_fname + self.model = Prediction(self.model_fname, gpu) + + self.input_root = input_root + self.output_root = output_root + self.split = split + self.verbose = verbose + self.extension = extension + self.use_feat = use_feat + + assert os.path.exists(self.input_path), "Input path '{}' does not exist".format( + self.input_path + ) + + def _progress(self, iterable, **kwargs): + if self.verbose: + return tqdm.tqdm(iterable, **kwargs) + return iterable + + def require_output_path(self, fname=None): + path = self.get_output_path(fname) + os.makedirs(path, exist_ok=True) + + @property + def input_path(self): + return self.get_input_path() + + @property + def output_path(self): + return self.get_output_path() + + def get_input_path(self, fname=None): + if fname is None: + return os.path.join(self.input_root, self.split) + return os.path.join(self.get_input_path(), fname) + + def get_output_path(self, fname=None): + if fname is None: + return os.path.join(self.output_root, self.split) + return os.path.join(self.get_output_path(), fname) + + def copy_labels(self): + self.require_output_path() + + labels = list( + filter( + lambda x: self.extension not in x, glob.glob(self.get_input_path("*")) + ) + ) + for fname in tqdm.tqdm(labels): + copy(fname, self.output_path) + + @property + def input_fnames(self): + return sorted(glob.glob(self.get_input_path("*.{}".format(self.extension)))) + + def __len__(self): + return len(self.input_fnames) + + def write_features(self): + + paths = self.input_fnames + + fnames_context = map( + lambda x: os.path.join( + self.output_path, x.replace("." 
+ self.extension, ".h5context") + ), + map(os.path.basename, paths), + ) + + for name, target_fname in self._progress( + zip(paths, fnames_context), total=len(self) + ): + wav, sr = read_audio(name) + z, c = self.model(wav) + feat = z if self.use_feat else c + writer = H5Writer(target_fname) + writer.write(feat) + + def __repr__(self): + + return "EmbeddingDatasetWriter ({n_files} files)\n\tinput:\t{input_root}\n\toutput:\t{output_root}\n\tsplit:\t{split})".format( + n_files=len(self), **self.__dict__ + ) + + +if __name__ == "__main__": + + args = EmbeddingWriterConfig().parse_args() + + for split in args.split: + + writer = EmbeddingDatasetWriter( + input_root=args.input, + output_root=args.output, + split=split, + model_fname=args.model, + gpu=args.gpu, + extension=args.ext, + use_feat=args.use_feat, + ) + + print(writer) + writer.require_output_path() + + print("Writing Features...") + writer.write_features() + print("Done.") + + if not args.no_copy_labels: + print("Copying label data...") + writer.copy_labels() + print("Done.") diff --git a/fairseq/examples/wav2vec/wav2vec_manifest.py b/fairseq/examples/wav2vec/wav2vec_manifest.py new file mode 100644 index 0000000000000000000000000000000000000000..9b8aa180e88d9ee98bdca7089aed5046ec0d9cb9 --- /dev/null +++ b/fairseq/examples/wav2vec/wav2vec_manifest.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Data pre-processing: build vocabularies and binarize training data. 
+""" + +import argparse +import glob +import os +import random + +import soundfile + + +def get_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "root", metavar="DIR", help="root directory containing flac files to index" + ) + parser.add_argument( + "--valid-percent", + default=0.01, + type=float, + metavar="D", + help="percentage of data to use as validation set (between 0 and 1)", + ) + parser.add_argument( + "--dest", default=".", type=str, metavar="DIR", help="output directory" + ) + parser.add_argument( + "--ext", default="flac", type=str, metavar="EXT", help="extension to look for" + ) + parser.add_argument("--seed", default=42, type=int, metavar="N", help="random seed") + parser.add_argument( + "--path-must-contain", + default=None, + type=str, + metavar="FRAG", + help="if set, path must contain this substring for a file to be included in the manifest", + ) + return parser + + +def main(args): + assert args.valid_percent >= 0 and args.valid_percent <= 1.0 + + if not os.path.exists(args.dest): + os.makedirs(args.dest) + + dir_path = os.path.realpath(args.root) + search_path = os.path.join(dir_path, "**/*." + args.ext) + rand = random.Random(args.seed) + + valid_f = ( + open(os.path.join(args.dest, "valid.tsv"), "w") + if args.valid_percent > 0 + else None + ) + + with open(os.path.join(args.dest, "train.tsv"), "w") as train_f: + print(dir_path, file=train_f) + + if valid_f is not None: + print(dir_path, file=valid_f) + + for fname in glob.iglob(search_path, recursive=True): + file_path = os.path.realpath(fname) + + if args.path_must_contain and args.path_must_contain not in file_path: + continue + + frames = soundfile.info(fname).frames + dest = train_f if rand.random() > args.valid_percent else valid_f + print( + "{}\t{}".format(os.path.relpath(file_path, dir_path), frames), file=dest + ) + if valid_f is not None: + valid_f.close() + + +if __name__ == "__main__": + parser = get_parser() + args = parser.parse_args() + main(args)