Add files using upload-large-folder tool
- fairseq-0.10.2/examples/backtranslation/README.md +297 -0
- fairseq-0.10.2/examples/backtranslation/deduplicate_lines.py +41 -0
- fairseq-0.10.2/examples/backtranslation/extract_bt_data.py +72 -0
- fairseq-0.10.2/examples/backtranslation/prepare-de-monolingual.sh +98 -0
- fairseq-0.10.2/examples/backtranslation/sacrebleu.sh +37 -0
- fairseq-0.10.2/examples/backtranslation/tokenized_bleu.sh +46 -0
- fairseq-0.10.2/examples/bart/README.glue.md +99 -0
- fairseq-0.10.2/examples/bart/README.md +218 -0
- fairseq-0.10.2/examples/bart/README.summarization.md +121 -0
- fairseq-0.10.2/examples/criss/download_and_preprocess_flores_test.sh +64 -0
- fairseq-0.10.2/examples/criss/download_and_preprocess_tatoeba.sh +37 -0
- fairseq-0.10.2/examples/criss/mining/mine.py +233 -0
- fairseq-0.10.2/examples/criss/mining/mine_example.sh +103 -0
- fairseq-0.10.2/examples/criss/save_encoder.py +213 -0
- fairseq-0.10.2/examples/criss/sentence_retrieval/encoder_analysis.py +92 -0
- fairseq-0.10.2/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh +59 -0
- fairseq-0.10.2/examples/criss/unsupervised_mt/eval.sh +37 -0
- fairseq-0.10.2/examples/cross_lingual_language_model/README.md +77 -0
- fairseq-0.10.2/examples/latent_depth/latent_depth_src/loss/latent_depth.py +99 -0
- fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/__init__.py +0 -0
- fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py +59 -0
- fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/latent_transformer.py +146 -0
- fairseq-0.10.2/examples/latent_depth/latent_depth_src/modules/__init__.py +0 -0
- fairseq-0.10.2/examples/latent_depth/latent_depth_src/modules/latent_layers.py +86 -0
- fairseq-0.10.2/examples/m2m_100/README.md +209 -0
- fairseq-0.10.2/examples/m2m_100/install_dependecies.sh +78 -0
- fairseq-0.10.2/examples/m2m_100/process_data/remove_too_much_punc.py +36 -0
- fairseq-0.10.2/examples/m2m_100/tok.sh +83 -0
- fairseq-0.10.2/examples/mbart/README.md +123 -0
- fairseq-0.10.2/examples/roberta/README.glue.md +99 -0
- fairseq-0.10.2/examples/roberta/README.md +296 -0
- fairseq-0.10.2/examples/roberta/README.pretraining.md +98 -0
- fairseq-0.10.2/examples/roberta/README.race.md +68 -0
- fairseq-0.10.2/examples/roberta/multiprocessing_bpe_encoder.py +130 -0
- fairseq-0.10.2/examples/roberta/preprocess_GLUE_tasks.sh +185 -0
- fairseq-0.10.2/examples/roberta/preprocess_RACE.sh +59 -0
- fairseq-0.10.2/examples/speech_recognition/README.md +106 -0
- fairseq-0.10.2/examples/speech_recognition/__init__.py +1 -0
- fairseq-0.10.2/examples/speech_recognition/criterions/__init__.py +17 -0
- fairseq-0.10.2/examples/speech_recognition/criterions/cross_entropy_acc.py +130 -0
- fairseq-0.10.2/examples/speech_recognition/datasets/asr_prep_json.py +125 -0
- fairseq-0.10.2/examples/speech_recognition/datasets/prepare-librispeech.sh +88 -0
- fairseq-0.10.2/examples/speech_recognition/infer.py +464 -0
- fairseq-0.10.2/examples/speech_recognition/utils/wer_utils.py +381 -0
- fairseq-0.10.2/examples/speech_recognition/w2l_decoder.py +435 -0
- fairseq-0.10.2/examples/speech_to_text/data_utils.py +262 -0
- fairseq-0.10.2/examples/speech_to_text/prep_covost_data.py +294 -0
- fairseq-0.10.2/examples/speech_to_text/prep_librispeech_data.py +119 -0
- fairseq-0.10.2/examples/speech_to_text/prep_mustc_data.py +200 -0
- fairseq-0.10.2/examples/unsupervised_quality_estimation/README.md +126 -0
fairseq-0.10.2/examples/backtranslation/README.md
ADDED
@@ -0,0 +1,297 @@

# Understanding Back-Translation at Scale (Edunov et al., 2018)

This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).

## Pre-trained models

Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive

## Example usage (torch.hub)

We require a few additional Python dependencies for preprocessing:
```bash
pip install subword_nmt sacremoses
```

Then to generate translations from the full model ensemble:
```python
import torch

# List available models
torch.hub.list('pytorch/fairseq')  # [..., 'transformer.wmt18.en-de', ...]

# Load the WMT'18 En-De ensemble
en2de_ensemble = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt18.en-de',
    checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
    tokenizer='moses', bpe='subword_nmt')

# The ensemble contains 5 models
len(en2de_ensemble.models)
# 5

# Translate
en2de_ensemble.translate('Hello world!')
# 'Hallo Welt!'
```

## Training your own model (WMT'18 English-German)

The following instructions can be adapted to reproduce the models from the paper.

#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model

First download and preprocess the data:
```bash
# Download and prepare the data
cd examples/backtranslation/
bash prepare-wmt18en2de.sh
cd ../..

# Binarize the data
TEXT=examples/backtranslation/wmt18_en_de
fairseq-preprocess \
    --joined-dictionary \
    --source-lang en --target-lang de \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20

# Copy the BPE code into the data-bin directory for future use
cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
```

(Optionally) Train a baseline model (English-German) using just the parallel data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel
fairseq-train --fp16 \
    data-bin/wmt18_en_de \
    --source-lang en --target-lang de \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 30000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
    --inputs $CHECKPOINT_DIR \
    --num-epoch-checkpoints 10 \
    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
```

Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
# compare to 29.46 in Table 1, which is also for tokenized BLEU

# generally it's better to report (detokenized) sacrebleu though:
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
```

#### Step 2. Back-translate monolingual German data

Train a reverse model (German-English) to do the back-translation:
```bash
CHECKPOINT_DIR=checkpoints_de_en_parallel
fairseq-train --fp16 \
    data-bin/wmt18_en_de \
    --source-lang de --target-lang en \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 30000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Let's evaluate the back-translation (BT) model to make sure it is well trained:
```bash
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    de-en \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint_best.pt
# BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
# compare to the best system from WMT'17 which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
```

Next prepare the monolingual data:
```bash
# Download and prepare the monolingual data
# By default the script samples 25M monolingual sentences, which after
# deduplication should be just over 24M sentences. These are split into 25
# shards, each with 1M sentences (except for the last shard).
cd examples/backtranslation/
bash prepare-de-monolingual.sh
cd ../..

# Binarize each shard of the monolingual data
TEXT=examples/backtranslation/wmt18_de_mono
for SHARD in $(seq -f "%02g" 0 24); do \
    fairseq-preprocess \
        --only-source \
        --source-lang de --target-lang en \
        --joined-dictionary \
        --srcdict data-bin/wmt18_en_de/dict.de.txt \
        --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
        --destdir data-bin/wmt18_de_mono/shard${SHARD} \
        --workers 20; \
    cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
done
```

Now we're ready to perform back-translation over the monolingual data. The
following command generates via sampling, but it's possible to use greedy
decoding (`--beam 1`), beam search (`--beam 5`),
top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
```bash
mkdir backtranslation_output
for SHARD in $(seq -f "%02g" 0 24); do \
    fairseq-generate --fp16 \
        data-bin/wmt18_de_mono/shard${SHARD} \
        --path $CHECKPOINT_DIR/checkpoint_best.pt \
        --skip-invalid-size-inputs-valid-test \
        --max-tokens 4096 \
        --sampling --beam 1 \
    > backtranslation_output/sampling.shard${SHARD}.out; \
done
```

After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
the back-translations and apply length ratio filters:
```bash
python examples/backtranslation/extract_bt_data.py \
    --minlen 1 --maxlen 250 --ratio 1.5 \
    --output backtranslation_output/bt_data --srclang en --tgtlang de \
    backtranslation_output/sampling.shard*.out

# Ensure lengths are the same:
# wc -l backtranslation_output/bt_data.{en,de}
#   21795614 backtranslation_output/bt_data.en
#   21795614 backtranslation_output/bt_data.de
#   43591228 total
```

Binarize the filtered BT data and combine it with the parallel data:
```bash
TEXT=backtranslation_output
fairseq-preprocess \
    --source-lang en --target-lang de \
    --joined-dictionary \
    --srcdict data-bin/wmt18_en_de/dict.en.txt \
    --trainpref $TEXT/bt_data \
    --destdir data-bin/wmt18_en_de_bt \
    --workers 20

# We want to train on the combined data, so we'll symlink the parallel + BT data
# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
# and the BT data as "train1", so that fairseq will combine them automatically
# and so that we can use the `--upsample-primary` option to upsample the
# parallel data (if desired).
PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
mkdir -p $COMB_DATA
for LANG in en de; do \
    ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
    for EXT in bin idx; do \
        ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
        ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
        ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
        ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
    done; \
done
```

#### Step 3. Train an English-German model over the combined parallel + BT data

Finally we can train a model over the parallel + BT data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
fairseq-train --fp16 \
    data-bin/wmt18_en_de_para_plus_bt \
    --upsample-primary 16 \
    --source-lang en --target-lang de \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 100000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
    --inputs $CHECKPOINT_DIR \
    --num-epoch-checkpoints 10 \
    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
```

Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
# compare to 32.35 in Table 1, which is also for tokenized BLEU

# generally it's better to report (detokenized) sacrebleu:
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
```

## Citation
```bibtex
@inproceedings{edunov2018backtranslation,
  title = {Understanding Back-Translation at Scale},
  author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
  booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  year = 2018,
}
```
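The `--minlen`/`--maxlen`/`--ratio` flags used when extracting BT data above amount to a simple per-pair predicate on whitespace-token counts. A minimal standalone sketch of that rule (the `keep_pair` helper is hypothetical, not part of fairseq):

```python
def keep_pair(src, tgt, minlen=1, maxlen=250, ratio=1.5):
    """Return True if a sentence pair passes the length filters
    applied when extracting back-translated data."""
    srclen = len(src.split()) if src else 0
    tgtlen = len(tgt.split()) if tgt else 0
    if srclen < minlen or tgtlen < minlen:
        return False
    if srclen > maxlen or tgtlen > maxlen:
        return False
    # length-ratio filter: drop pairs whose lengths diverge too much
    if max(srclen, tgtlen) / max(min(srclen, tgtlen), 1) > ratio:
        return False
    return True

print(keep_pair("ein Satz mit fünf Wörtern", "a five word sentence here"))  # True
print(keep_pair("kurz", "a very very very long hypothesis indeed"))         # False
```

With the defaults shown (matching the README's `--minlen 1 --maxlen 250 --ratio 1.5`), the second pair is rejected because its length ratio exceeds 1.5.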
fairseq-0.10.2/examples/backtranslation/deduplicate_lines.py
ADDED
@@ -0,0 +1,41 @@

#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput
import hashlib
import sys
from multiprocessing import Pool


def get_hashes_and_lines(raw_line):
    hash = hashlib.md5(raw_line).hexdigest()
    return hash, raw_line


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=10)
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    seen = set()
    with fileinput.input(args.files, mode="rb") as h:
        pool = Pool(args.workers)
        results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
        for i, (hash, raw_line) in enumerate(results):
            if hash not in seen:
                seen.add(hash)
                sys.stdout.buffer.write(raw_line)
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
        print(file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()
fairseq-0.10.2/examples/backtranslation/extract_bt_data.py
ADDED
|
@@ -0,0 +1,72 @@

#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput

from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiple hypotheses for a source, we only keep the first one. "
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    def validate(src, tgt):
        srclen = len(src.split(" ")) if src != "" else 0
        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
        if (
            (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
            or (
                args.maxlen is not None
                and (srclen > args.maxlen or tgtlen > args.maxlen)
            )
            or (
                args.ratio is not None
                and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
            )
        ):
            return False
        return True

    def safe_index(toks, index, default):
        try:
            return toks[index]
        except IndexError:
            return default

    with open(args.output + "." + args.srclang, "w") as src_h, open(
        args.output + "." + args.tgtlang, "w"
    ) as tgt_h:
        tgt = None  # initialized so a stray leading H- line is ignored
        for line in tqdm(fileinput.input(args.files)):
            if line.startswith("S-"):
                tgt = safe_index(line.rstrip().split("\t"), 1, "")
            elif line.startswith("H-"):
                if tgt is not None:
                    src = safe_index(line.rstrip().split("\t"), 2, "")
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    tgt = None


if __name__ == "__main__":
    main()
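`fairseq-generate` interleaves source lines (`S-<id>\t<text>`) and hypothesis lines (`H-<id>\t<score>\t<text>`), which is why the script pairs each first `H-` line with the preceding `S-` line and then resets. That pairing logic can be sketched on synthetic output (the `extract_pairs` helper is illustrative only; note that in back-translation the hypothesis becomes the new *source* side):

```python
def extract_pairs(lines):
    """Pair each first H- hypothesis with the preceding S- source line."""
    pairs, tgt = [], None
    for line in lines:
        if line.startswith("S-"):
            parts = line.rstrip().split("\t")
            tgt = parts[1] if len(parts) > 1 else ""
        elif line.startswith("H-") and tgt is not None:
            parts = line.rstrip().split("\t")
            src = parts[2] if len(parts) > 2 else ""
            pairs.append((src, tgt))
            tgt = None  # keep only the first hypothesis per source
    return pairs

sample = [
    "S-0\tHallo Welt !",
    "H-0\t-0.23\tHello world !",
    "H-0\t-0.91\tHi world !",   # second hypothesis, ignored
    "S-1\tGuten Morgen",
    "H-1\t-0.40\tGood morning",
]
print(extract_pairs(sample))
# [('Hello world !', 'Hallo Welt !'), ('Good morning', 'Guten Morgen')]
```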
fairseq-0.10.2/examples/backtranslation/prepare-de-monolingual.sh
ADDED
@@ -0,0 +1,98 @@

#!/bin/bash

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt


BPE_CODE=wmt18_en_de/code
SUBSAMPLE_SIZE=25000000
LANG=de


OUTDIR=wmt18_${LANG}_mono
orig=orig
tmp=$OUTDIR/tmp
mkdir -p $OUTDIR $tmp


URLS=(
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz"
    "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz"
    "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz"
    "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz"
    "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz"
)
FILES=(
    "news.2007.de.shuffled.gz"
    "news.2008.de.shuffled.gz"
    "news.2009.de.shuffled.gz"
    "news.2010.de.shuffled.gz"
    "news.2011.de.shuffled.gz"
    "news.2012.de.shuffled.gz"
    "news.2013.de.shuffled.gz"
    "news.2014.de.shuffled.v2.gz"
    "news.2015.de.shuffled.gz"
    "news.2016.de.shuffled.gz"
    "news.2017.de.shuffled.deduped.gz"
)


cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
    fi
done
cd ..


if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found monolingual sample, skipping shuffle/sample/tokenize"
else
    gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
        | shuf -n $SUBSAMPLE_SIZE \
        | perl $NORM_PUNC $LANG \
        | perl $REM_NON_PRINT_CHAR \
        | perl $TOKENIZER -threads 8 -a -l $LANG \
        > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
fi


if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found BPE monolingual sample, skipping BPE step"
else
    python $BPEROOT/apply_bpe.py -c $BPE_CODE \
        < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
        > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
fi


if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found deduplicated monolingual sample, skipping deduplication step"
else
    python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
        > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
fi


if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then
    echo "found sharded data, skipping sharding step"
else
    split --lines 1000000 --numeric-suffixes \
        --additional-suffix .${LANG} \
        $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
        $OUTDIR/bpe.monolingual.dedup.
fi
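The final `split --lines 1000000 --numeric-suffixes --additional-suffix .de` step just writes consecutive fixed-size shards keyed by two-digit numeric suffixes (`00`, `01`, …), which is exactly the naming the README's `seq -f "%02g" 0 24` loop expects. An in-memory Python equivalent for illustration (the `shard_lines` helper is hypothetical):

```python
def shard_lines(lines, shard_size):
    """Split lines into consecutive shards of at most shard_size,
    keyed by a two-digit numeric suffix like GNU split --numeric-suffixes."""
    shards = {}
    for i in range(0, len(lines), shard_size):
        shards[f"{i // shard_size:02d}"] = lines[i:i + shard_size]
    return shards

lines = [f"sentence {n}" for n in range(5)]
shards = shard_lines(lines, 2)
print(sorted(shards))   # ['00', '01', '02']
print(shards["02"])     # ['sentence 4']  (the last shard may be short)
```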
fairseq-0.10.2/examples/backtranslation/sacrebleu.sh
ADDED
@@ -0,0 +1,37 @@

#!/bin/bash

if [ $# -ne 5 ]; then
    echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
    exit
fi


DATASET=$1
LANGPAIR=$2
DATABIN=$3
BPECODE=$4
MODEL=$5

SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)


BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
    BPEROOT=subword-nmt/subword_nmt
    if [ ! -e $BPEROOT ]; then
        echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
        git clone https://github.com/rsennrich/subword-nmt.git
    fi
fi


sacrebleu -t $DATASET -l $LANGPAIR --echo src \
| sacremoses tokenize -a -l $SRCLANG -q \
| python $BPEROOT/apply_bpe.py -c $BPECODE \
| fairseq-interactive $DATABIN --path $MODEL \
    -s $SRCLANG -t $TGTLANG \
    --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
| grep ^H- | cut -f 3- \
| sacremoses detokenize -l $TGTLANG -q \
| sacrebleu -t $DATASET -l $LANGPAIR
fairseq-0.10.2/examples/backtranslation/tokenized_bleu.sh
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+#!/bin/bash
+
+if [ $# -ne 5 ]; then
+  echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
+  exit
+fi
+
+
+DATASET=$1
+LANGPAIR=$2
+DATABIN=$3
+BPECODE=$4
+MODEL=$5
+
+SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
+TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)
+
+
+BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
+if [ ! -e $BPEROOT ]; then
+  BPEROOT=subword-nmt/subword_nmt
+  if [ ! -e $BPEROOT ]; then
+    echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
+    git clone https://github.com/rsennrich/subword-nmt.git
+  fi
+fi
+
+
+TMP_REF=$(mktemp)
+
+sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
+| sacremoses normalize -l $TGTLANG -q \
+| sacremoses tokenize -a -l $TGTLANG -q \
+> $TMP_REF
+
+sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
+| sacremoses normalize -l $SRCLANG -q \
+| sacremoses tokenize -a -l $SRCLANG -q \
+| python $BPEROOT/apply_bpe.py -c $BPECODE \
+| fairseq-interactive $DATABIN --path $MODEL \
+  -s $SRCLANG -t $TGTLANG \
+  --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
+| grep ^H- | cut -f 3- \
+| fairseq-score --ref $TMP_REF
+
+rm -f $TMP_REF
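The two `cut -d '-'` calls in the script split a language pair like `en-de` into source and target codes; the same split in Python, shown only to make the convention explicit:

```python
def split_langpair(langpair):
    """Split 'en-de' into ('en', 'de'), mirroring the two `cut -d '-'` calls."""
    src, tgt = langpair.split("-", 1)
    return src, tgt

print(split_langpair("en-de"))
```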
fairseq-0.10.2/examples/bart/README.glue.md
ADDED
@@ -0,0 +1,99 @@
+# Fine-tuning BART on GLUE tasks
+
+### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
+```bash
+wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
+python download_glue_data.py --data_dir glue_data --tasks all
+```
+
+### 2) Preprocess GLUE task data (same as RoBERTa):
+```bash
+./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
+```
+`glue_task_name` is one of the following:
+`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
+Use `ALL` to preprocess all the GLUE tasks.
+
+### 3) Fine-tuning on a GLUE task:
+Example fine-tuning command for the `RTE` task:
+```bash
+TOTAL_NUM_UPDATES=2036  # 10 epochs through RTE for bsz 16
+WARMUP_UPDATES=61       # 6 percent of the number of updates
+LR=1e-05                # Peak LR for polynomial LR scheduler.
+NUM_CLASSES=2
+MAX_SENTENCES=16        # Batch size.
+BART_PATH=/path/to/bart/model.pt
+
+CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
+    --restore-file $BART_PATH \
+    --batch-size $MAX_SENTENCES \
+    --max-tokens 4400 \
+    --task sentence_prediction \
+    --add-prev-output-tokens \
+    --layernorm-embedding \
+    --share-all-embeddings \
+    --share-decoder-input-output-embed \
+    --reset-optimizer --reset-dataloader --reset-meters \
+    --required-batch-size-multiple 1 \
+    --init-token 0 \
+    --arch bart_large \
+    --criterion sentence_prediction \
+    --num-classes $NUM_CLASSES \
+    --dropout 0.1 --attention-dropout 0.1 \
+    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
+    --clip-norm 0.0 \
+    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
+    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
+    --max-epoch 10 \
+    --find-unused-parameters \
+    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
+```
+
+For each GLUE task, you will need to use the following command-line arguments:
+
+Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
+---|---|---|---|---|---|---|---|---
+`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
+`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
+`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
+`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
+`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107
+
+For `STS-B`, additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
+
+**Note:**
+
+a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128` depending on the task.
+
+b) The above command-line arguments and hyperparameters were tested on an Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.
+
+### Inference on a GLUE task
+After training the model as described in the previous step, you can perform inference with checkpoints in the `checkpoints/` directory using the following Python code snippet:
+
+```python
+from fairseq.models.bart import BARTModel
+
+bart = BARTModel.from_pretrained(
+    'checkpoints/',
+    checkpoint_file='checkpoint_best.pt',
+    data_name_or_path='RTE-bin'
+)
+
+label_fn = lambda label: bart.task.label_dictionary.string(
+    [label + bart.task.label_dictionary.nspecial]
+)
+ncorrect, nsamples = 0, 0
+bart.cuda()
+bart.eval()
+with open('glue_data/RTE/dev.tsv') as fin:
+    fin.readline()
+    for index, line in enumerate(fin):
+        tokens = line.strip().split('\t')
+        sent1, sent2, target = tokens[1], tokens[2], tokens[3]
+        tokens = bart.encode(sent1, sent2)
+        prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
+        prediction_label = label_fn(prediction)
+        ncorrect += int(prediction_label == target)
+        nsamples += 1
+print('| Accuracy: ', float(ncorrect)/float(nsamples))
+```
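The accuracy loop in the inference snippet can be exercised without fairseq or a GPU; below is a minimal sketch of the same counting logic with a hypothetical in-memory dev set and a stub predictor (both the data and the word-overlap heuristic are assumptions for illustration, not part of BART):

```python
# Hypothetical dev examples: (sentence1, sentence2, gold_label)
dev_set = [
    ("A man is sleeping.", "A person is asleep.", "entailment"),
    ("A man is sleeping.", "A man is running.", "not_entailment"),
]

def stub_predict(sent1, sent2):
    """Stand-in for bart.predict(...): a trivial word-overlap heuristic."""
    overlap = len(set(sent1.lower().split()) & set(sent2.lower().split()))
    return "entailment" if overlap >= 2 else "not_entailment"

# Same accumulate-and-divide pattern as the README snippet
ncorrect = sum(int(stub_predict(s1, s2) == gold) for s1, s2, gold in dev_set)
accuracy = ncorrect / len(dev_set)
print("| Accuracy: ", accuracy)
```

The stub deliberately misclassifies the second pair, so the printed accuracy is 0.5; with a real model the loop is identical, only the predictor changes.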
fairseq-0.10.2/examples/bart/README.md
ADDED
@@ -0,0 +1,218 @@
+# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension
+
+[https://arxiv.org/pdf/1910.13461.pdf]
+
+## Introduction
+
+BART is a sequence-to-sequence model trained with denoising as its pretraining objective. We show that this pretraining objective is more generic and that we can match [RoBERTa](../roberta) results on SQuAD and GLUE and achieve state-of-the-art results on summarization (XSum, CNN dataset), long-form generative question answering (ELI5), and dialog response generation (ConvAI2). See the associated paper for more details.
+
+## Pre-trained models
+
+Model | Description | # params | Download
+---|---|---|---
+`bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz)
+`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
+`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
+`bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
+`bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz)
+
+## Results
+
+**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
+_(dev set, single model, single-task finetuning)_
+
+Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
+---|---|---|---|---|---|---|---|---
+`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
+`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2
+
+**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
+_(dev set, no additional data used)_
+
+Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
+---|---|---
+`roberta.large` | 88.9/94.6 | 86.5/89.4
+`bart.large` | 88.8/94.6 | 86.1/89.2
+
+**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
+_(test set, no additional data used)_
+
+Model | R1 | R2 | RL
+---|---|---|---
+`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
+`bart.large` | 44.16 | 21.28 | 40.90
+
+## Example usage
+
+##### Load BART from torch.hub (PyTorch >= 1.1):
+```python
+import torch
+bart = torch.hub.load('pytorch/fairseq', 'bart.large')
+bart.eval()  # disable dropout (or leave in train mode to finetune)
+```
+
+##### Load BART (for PyTorch 1.0 or custom models):
+```python
+# Download bart.large model
+wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
+tar -xzvf bart.large.tar.gz
+
+# Load the model in fairseq
+from fairseq.models.bart import BARTModel
+bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
+bart.eval()  # disable dropout (or leave in train mode to finetune)
+```
+
+##### Apply Byte-Pair Encoding (BPE) to input text:
+```python
+tokens = bart.encode('Hello world!')
+assert tokens.tolist() == [0, 31414, 232, 328, 2]
+bart.decode(tokens)  # 'Hello world!'
+```
+
+##### Extract features from BART:
+```python
+# Extract the last layer's features
+last_layer_features = bart.extract_features(tokens)
+assert last_layer_features.size() == torch.Size([1, 5, 1024])
+
+# Extract all layers' features from the decoder (layer 0 is the embedding layer)
+all_layers = bart.extract_features(tokens, return_all_hiddens=True)
+assert len(all_layers) == 13
+assert torch.all(all_layers[-1] == last_layer_features)
+```
+
+##### Use BART for sentence-pair classification tasks:
+```python
+# Download BART already finetuned for MNLI
+bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
+bart.eval()  # disable dropout for evaluation
+
+# Encode a pair of sentences and make a prediction
+tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
+bart.predict('mnli', tokens).argmax()  # 0: contradiction
+
+# Encode another pair of sentences
+tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
+bart.predict('mnli', tokens).argmax()  # 2: entailment
+```
+
+##### Register a new (randomly initialized) classification head:
+```python
+bart.register_classification_head('new_task', num_classes=3)
+logprobs = bart.predict('new_task', tokens)
+```
+
+##### Batched prediction:
+```python
+import torch
+from fairseq.data.data_utils import collate_tokens
+
+bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
+bart.eval()
+
+batch_of_pairs = [
+    ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
+    ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
+]
+
+batch = collate_tokens(
+    [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
+)
+
+logprobs = bart.predict('mnli', batch)
+print(logprobs.argmax(dim=1))
+# tensor([0, 2])
+```
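`collate_tokens` right-pads variable-length token sequences to a common length so they can be stacked into one batch tensor; a minimal pure-Python sketch of that padding step (the pad index `1` matches the `pad_idx=1` call above; the sample token sequences are made up):

```python
def collate(seqs, pad_idx=1):
    """Right-pad integer sequences to the length of the longest one,
    mirroring what collate_tokens does before stacking into a tensor."""
    max_len = max(len(s) for s in seqs)
    return [s + [pad_idx] * (max_len - len(s)) for s in seqs]

batch = collate([[0, 7, 2], [0, 7, 8, 9, 2]])
print(batch)
```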
+
+##### Using the GPU:
+```python
+bart.cuda()
+bart.predict('new_task', tokens)
+```
+
+#### Evaluating the `bart.large.mnli` model:
+
+Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set:
+```python
+label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
+ncorrect, nsamples = 0, 0
+bart.cuda()
+bart.eval()
+with open('glue_data/MNLI/dev_matched.tsv') as fin:
+    fin.readline()
+    for index, line in enumerate(fin):
+        tokens = line.strip().split('\t')
+        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
+        tokens = bart.encode(sent1, sent2)
+        prediction = bart.predict('mnli', tokens).argmax().item()
+        prediction_label = label_map[prediction]
+        ncorrect += int(prediction_label == target)
+        nsamples += 1
+print('| Accuracy: ', float(ncorrect)/float(nsamples))
+# Expected output: 0.9010
+```
+
+#### Evaluating the `bart.large.cnn` model:
+Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into files such that `test.source` and `test.target` have one line for each non-tokenized sample.
+
+```python
+bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
+bart.cuda()
+bart.eval()
+bart.half()
+count = 1
+bsz = 32
+with open('test.source') as source, open('test.hypo', 'w') as fout:
+    sline = source.readline().strip()
+    slines = [sline]
+    for sline in source:
+        if count % bsz == 0:
+            with torch.no_grad():
+                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
+
+            for hypothesis in hypotheses_batch:
+                fout.write(hypothesis + '\n')
+            fout.flush()
+            slines = []
+
+        slines.append(sline.strip())
+        count += 1
+    if slines != []:
+        hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
+        for hypothesis in hypotheses_batch:
+            fout.write(hypothesis + '\n')
+            fout.flush()
+```
+
+Install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).
+
+```bash
+export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar
+
+# Tokenize hypothesis and target files.
+cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
+cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
+files2rouge test.hypo.tokenized test.hypo.target
+# Expected output: (ROUGE-2 Average_F: 0.21238)
+```
+
+
+## Finetuning
+
+- [Finetuning on GLUE](README.glue.md)
+- [Finetuning on CNN-DM](README.summarization.md)
+
+## Citation
+
+```bibtex
+@article{lewis2019bart,
+  title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
+  Language Generation, Translation, and Comprehension},
+  author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
+  Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
+  and Luke Zettlemoyer},
+  journal = {arXiv preprint arXiv:1910.13461},
+  year = {2019},
+}
+```
fairseq-0.10.2/examples/bart/README.summarization.md
ADDED
@@ -0,0 +1,121 @@
+# Fine-tuning BART on the CNN-DailyMail summarization task
+
+### 1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.
+
+Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in [this issue](https://github.com/pytorch/fairseq/issues/1391) or check out the code [here](https://github.com/artmatsak/cnn-dailymail).
+
+Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to download the original Extreme Summarization dataset, or check out the code [here](https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset). Please keep the dataset raw and make sure it is neither tokenized nor BPE-encoded.
+
+### 2) BPE preprocess:
+
+```bash
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
+
+TASK=cnn_dm
+for SPLIT in train val
+do
+  for LANG in source target
+  do
+    python -m examples.roberta.multiprocessing_bpe_encoder \
+    --encoder-json encoder.json \
+    --vocab-bpe vocab.bpe \
+    --inputs "$TASK/$SPLIT.$LANG" \
+    --outputs "$TASK/$SPLIT.bpe.$LANG" \
+    --workers 60 \
+    --keep-empty;
+  done
+done
+```
+
+### 3) Binarize dataset:
+```bash
+fairseq-preprocess \
+  --source-lang "source" \
+  --target-lang "target" \
+  --trainpref "${TASK}/train.bpe" \
+  --validpref "${TASK}/val.bpe" \
+  --destdir "${TASK}-bin/" \
+  --workers 60 \
+  --srcdict dict.txt \
+  --tgtdict dict.txt;
+```
+
+### 4) Fine-tuning on the CNN-DM summarization task:
+Example fine-tuning command for CNN-DM:
+```bash
+TOTAL_NUM_UPDATES=20000
+WARMUP_UPDATES=500
+LR=3e-05
+MAX_TOKENS=2048
+UPDATE_FREQ=4
+BART_PATH=/path/to/bart/model.pt
+
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \
+    --restore-file $BART_PATH \
+    --max-tokens $MAX_TOKENS \
+    --task translation \
+    --source-lang source --target-lang target \
+    --truncate-source \
+    --layernorm-embedding \
+    --share-all-embeddings \
+    --share-decoder-input-output-embed \
+    --reset-optimizer --reset-dataloader --reset-meters \
+    --required-batch-size-multiple 1 \
+    --arch bart_large \
+    --criterion label_smoothed_cross_entropy \
+    --label-smoothing 0.1 \
+    --dropout 0.1 --attention-dropout 0.1 \
+    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
+    --clip-norm 0.1 \
+    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
+    --fp16 --update-freq $UPDATE_FREQ \
+    --skip-invalid-size-inputs-valid-test \
+    --find-unused-parameters;
+```
+The above is expected to run on `1` node with `8` 32GB `V100` GPUs.
+The expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.
+
+Use `TOTAL_NUM_UPDATES=15000 UPDATE_FREQ=2` for the XSum task.
+
+### Inference on CNN-DM test data using the above trained checkpoint.
+After training the model as described in the previous step, you can perform inference with checkpoints in the `checkpoints/` directory using the following Python code snippet:
+
+```python
+import torch
+from fairseq.models.bart import BARTModel
+
+bart = BARTModel.from_pretrained(
+    'checkpoints/',
+    checkpoint_file='checkpoint_best.pt',
+    data_name_or_path='cnn_dm-bin'
+)
+
+bart.cuda()
+bart.eval()
+bart.half()
+count = 1
+bsz = 32
+with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout:
+    sline = source.readline().strip()
+    slines = [sline]
+    for sline in source:
+        if count % bsz == 0:
+            with torch.no_grad():
+                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
+
+            for hypothesis in hypotheses_batch:
+                fout.write(hypothesis + '\n')
+            fout.flush()
+            slines = []
+
+        slines.append(sline.strip())
+        count += 1
+    if slines != []:
+        hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
+        for hypothesis in hypotheses_batch:
+            fout.write(hypothesis + '\n')
+            fout.flush()
+```
+Use `beam=6, lenpen=1.0, max_len_b=60, min_len=10` for XSum generation.
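With `MAX_TOKENS=2048`, `UPDATE_FREQ=4`, and 8 GPUs, the effective number of tokens processed per optimizer step is roughly the product of the three (up to padding and the final partial batch). A quick check of that arithmetic:

```python
max_tokens = 2048   # tokens per GPU per forward pass (--max-tokens)
update_freq = 4     # gradient accumulation steps (--update-freq)
num_gpus = 8        # data-parallel workers

# Approximate effective batch size in tokens per optimizer update
effective_tokens_per_update = max_tokens * update_freq * num_gpus
print(effective_tokens_per_update)  # 65536
```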
fairseq-0.10.2/examples/criss/download_and_preprocess_flores_test.sh
ADDED
@@ -0,0 +1,64 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+SPM_ENCODE=flores/scripts/spm_encode.py
+DATA=data_tmp
+SPM_MODEL=criss_checkpoints/sentence.bpe.model
+DICT=criss_checkpoints/dict.txt
+
+download_data() {
+  CORPORA=$1
+  URL=$2
+
+  if [ -f $CORPORA ]; then
+    echo "$CORPORA already exists, skipping download"
+  else
+    echo "Downloading $URL"
+    wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA
+    if [ -f $CORPORA ]; then
+      echo "$URL successfully downloaded."
+    else
+      echo "$URL not successfully downloaded."
+      rm -f $CORPORA
+    fi
+  fi
+}
+
+if [[ -f flores ]]; then
+  echo "flores already cloned"
+else
+  git clone https://github.com/facebookresearch/flores
+fi
+
+mkdir -p $DATA
+download_data $DATA/wikipedia_en_ne_si_test_sets.tgz "https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz"
+pushd $DATA
+pwd
+tar -vxf wikipedia_en_ne_si_test_sets.tgz
+popd
+
+
+for lang in ne_NP si_LK; do
+  datadir=$DATA/${lang}-en_XX-flores
+  rm -rf $datadir
+  mkdir -p $datadir
+  TEST_PREFIX=$DATA/wikipedia_en_ne_si_test_sets/wikipedia.test
+  python $SPM_ENCODE \
+    --model ${SPM_MODEL} \
+    --output_format=piece \
+    --inputs ${TEST_PREFIX}.${lang:0:2}-en.${lang:0:2} ${TEST_PREFIX}.${lang:0:2}-en.en \
+    --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX
+
+  # binarize data
+  fairseq-preprocess \
+    --source-lang ${lang} --target-lang en_XX \
+    --testpref $datadir/test.bpe.${lang}-en_XX \
+    --destdir $datadir \
+    --srcdict ${DICT} \
+    --joined-dictionary \
+    --workers 4
+done
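The `${lang:0:2}` substring expansion in the script maps a locale-style code like `ne_NP` to the two-letter code used in the FLoRes test-set file names; the equivalent slice in Python:

```python
def short_code(lang):
    """Equivalent of bash ${lang:0:2}: the first two characters of the code."""
    return lang[:2]

print([short_code(l) for l in ["ne_NP", "si_LK"]])
```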
fairseq-0.10.2/examples/criss/download_and_preprocess_tatoeba.sh
ADDED
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+SPM_ENCODE=flores/scripts/spm_encode.py
+DATA=data_tmp
+SPM_MODEL=criss_checkpoints/sentence.bpe.model
+DICT=criss_checkpoints/dict.txt
+
+git clone https://github.com/facebookresearch/LASER
+mkdir -p data_tmp
+declare -A lang_tatoeba_map=( ["ar_AR"]="ara" ["de_DE"]="deu" ["es_XX"]="spa" ["et_EE"]="est" ["fi_FI"]="fin" ["fr_XX"]="fra" ["hi_IN"]="hin" ["it_IT"]="ita" ["ja_XX"]="jpn" ["ko_KR"]="kor" ["kk_KZ"]="kaz" ["nl_XX"]="nld" ["ru_RU"]="rus" ["tr_TR"]="tur" ["vi_VN"]="vie" ["zh_CN"]="cmn")
+for lang in ar_AR de_DE es_XX et_EE fi_FI fr_XX hi_IN it_IT ja_XX kk_KZ ko_KR nl_XX ru_RU tr_TR vi_VN zh_CN; do
+  lang_tatoeba=${lang_tatoeba_map[$lang]}
+  echo $lang_tatoeba
+  datadir=$DATA/${lang}-en_XX-tatoeba
+  rm -rf $datadir
+  mkdir -p $datadir
+  TEST_PREFIX=LASER/data/tatoeba/v1/tatoeba
+  python $SPM_ENCODE \
+    --model ${SPM_MODEL} \
+    --output_format=piece \
+    --inputs ${TEST_PREFIX}.${lang_tatoeba}-eng.${lang_tatoeba} ${TEST_PREFIX}.${lang_tatoeba}-eng.eng \
+    --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX
+
+  # binarize data
+  fairseq-preprocess \
+    --source-lang ${lang} --target-lang en_XX \
+    --testpref $datadir/test.bpe.${lang}-en_XX \
+    --destdir $datadir \
+    --srcdict ${DICT} \
+    --joined-dictionary \
+    --workers 4
+done
fairseq-0.10.2/examples/criss/mining/mine.py
ADDED
@@ -0,0 +1,233 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob
from subprocess import check_call

import faiss
import numpy as np


GB = 1024 * 1024 * 1024


def call(cmd):
    print(cmd)
    check_call(cmd, shell=True)


def get_batches(directory, lang, prefix="all_avg_pool"):
    print(f"Finding in {directory}/{prefix}.{lang}*")
    files = glob.glob(f"{directory}/{prefix}.{lang}*")
    emb_files = []
    txt_files = []
    for emb_fi in files:
        emb_files.append(emb_fi)
        txt_fi = emb_fi.replace(prefix, "sentences")
        txt_files.append(txt_fi)
    return emb_files, txt_files


def load_batch(emb_file, dim):
    embeddings = np.fromfile(emb_file, dtype=np.float32)
    num_rows = int(embeddings.shape[0] / dim)
    embeddings = embeddings.reshape((num_rows, dim))
    faiss.normalize_L2(embeddings)
    return embeddings


def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
    sims = []
    inds = []
    xfrom = 0
    xto = 0
    for x_batch_f in x_batches_f:
        yfrom = 0
        yto = 0
        x_batch = load_batch(x_batch_f, dim)
        xto = xfrom + x_batch.shape[0]
        bsims, binds = [], []
        for y_batch_f in y_batches_f:
            y_batch = load_batch(y_batch_f, dim)
            neighbor_size = min(k, y_batch.shape[0])
            yto = yfrom + y_batch.shape[0]
            print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
            idx = faiss.IndexFlatIP(dim)
            idx = faiss.index_cpu_to_all_gpus(idx)
            idx.add(y_batch)
            bsim, bind = idx.search(x_batch, neighbor_size)

            bsims.append(bsim)
            binds.append(bind + yfrom)
            yfrom += y_batch.shape[0]
            del idx
            del y_batch
        bsims = np.concatenate(bsims, axis=1)
        binds = np.concatenate(binds, axis=1)
        aux = np.argsort(-bsims, axis=1)
        sim_batch = np.zeros((x_batch.shape[0], k), dtype=np.float32)
        ind_batch = np.zeros((x_batch.shape[0], k), dtype=np.int64)
        for i in range(x_batch.shape[0]):
            for j in range(k):
                sim_batch[i, j] = bsims[i, aux[i, j]]
                ind_batch[i, j] = binds[i, aux[i, j]]
        sims.append(sim_batch)
        inds.append(ind_batch)
        xfrom += x_batch.shape[0]
        del x_batch
    sim = np.concatenate(sims, axis=0)
    ind = np.concatenate(inds, axis=0)
    return sim, ind


def score(sim, fwd_mean, bwd_mean, margin):
    return margin(sim, (fwd_mean + bwd_mean) / 2)


def score_candidates(
    sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
):
    print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
    scores = np.zeros(candidate_inds.shape)
    for i in range(scores.shape[0]):
        for j in range(scores.shape[1]):
            k = int(candidate_inds[i, j])
            scores[i, j] = score(sim_mat[i, j], fwd_mean[i], bwd_mean[k], margin)
    return scores


def load_text(files):
    all_sentences = []
    for fi in files:
        with open(fi) as sentence_fi:
            for line in sentence_fi:
                all_sentences.append(line.strip())
    print(f"Read {len(all_sentences)} sentences")
    return all_sentences


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Mine bitext")
    parser.add_argument("--src-lang", help="Source language")
    parser.add_argument("--tgt-lang", help="Target language")
    parser.add_argument(
        "--dict-path", help="Path to dictionary file", default="dict.txt"
    )
    parser.add_argument(
        "--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
    )
    parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
    parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
    parser.add_argument("--src-dir", help="Source directory")
    parser.add_argument("--tgt-dir", help="Target directory")
    parser.add_argument("--output", help="Output path")
    parser.add_argument(
        "--neighborhood", type=int, default=4, help="Neighborhood size (k nearest neighbors)"
    )
    parser.add_argument(
        "--threshold", type=float, default=1.06, help="Threshold on mined bitext"
    )
    parser.add_argument(
        "--valid-size",
        type=int,
        default=2000,
        help="Number of sentences used for validation set",
    )
    parser.add_argument(
        "--min-count",
        type=int,
        default=50000,
        help="Min num sentences used for each language",
    )
    args = parser.parse_args()

    x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
    y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
    margin = lambda a, b: a / b
    y2x_sim, y2x_ind = knnGPU_sharded(
        y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
    )
    x2y_sim, x2y_ind = knnGPU_sharded(
        x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
    )

    x2y_mean = x2y_sim.mean(axis=1)
    y2x_mean = y2x_sim.mean(axis=1)
    fwd_scores = score_candidates(x2y_sim, x2y_ind, x2y_mean, y2x_mean, margin)
    bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
    fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
    bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
    indices = np.stack(
        (
            np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
            np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
        ),
        axis=1,
    )
    scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))

    x_sentences = load_text(x_sents_f)
    y_sentences = load_text(y_sents_f)

    threshold = args.threshold
    min_count = args.min_count
    seen_src, seen_trg = set(), set()
    directory = args.output
    call(f"mkdir -p {directory}")
    src_out = open(
        f"{directory}/all.{args.src_lang}",
        mode="w",
        encoding="utf-8",
        errors="surrogateescape",
    )
    tgt_out = open(
        f"{directory}/all.{args.tgt_lang}",
        mode="w",
        encoding="utf-8",
        errors="surrogateescape",
    )
    scores_out = open(
        f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
    )
    count = 0
    for i in np.argsort(-scores):
        src_ind, trg_ind = indices[i]
        if src_ind not in seen_src and trg_ind not in seen_trg:
            seen_src.add(src_ind)
            seen_trg.add(trg_ind)
            if scores[i] > threshold or count < min_count:
                if x_sentences[src_ind]:
                    print(scores[i], file=scores_out)
                    print(x_sentences[src_ind], file=src_out)
                    print(y_sentences[trg_ind], file=tgt_out)
                    count += 1
                else:
                    print(f"Ignoring sentence: {x_sentences[src_ind]}")
    src_out.close()
    tgt_out.close()
    scores_out.close()

    print(f"Found {count} pairs for threshold={threshold}")
    with open(f"{directory}/all.{args.src_lang}") as all_s, open(
        f"{directory}/all.{args.tgt_lang}"
    ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
        f"{directory}/valid.{args.tgt_lang}", "w"
    ) as valid_t, open(
        f"{directory}/train.{args.src_lang}", "w"
    ) as train_s, open(
        f"{directory}/train.{args.tgt_lang}", "w"
    ) as train_t:
        count = 0
        for s_line, t_line in zip(all_s, all_t):
            s_line = s_line.split("\t")[1]
            t_line = t_line.split("\t")[1]
            if count >= args.valid_size:
                train_s.write(s_line)
                train_t.write(t_line)
            else:
                valid_s.write(s_line)
                valid_t.write(t_line)
            count += 1
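mine.py ranks candidate pairs with a ratio margin (`margin = lambda a, b: a / b`): each raw similarity is divided by the average of the forward and backward neighborhood means, which penalizes sentences that are close to everything. A minimal NumPy-only sketch of the same idea on a hypothetical 3x3 similarity matrix (no faiss; values are made up for illustration):

```python
import numpy as np

# Toy cosine-similarity matrix between 3 source and 3 target sentences
# (rows: source, columns: target). Hypothetical values.
sim = np.array([
    [0.9, 0.2, 0.1],
    [0.3, 0.8, 0.2],
    [0.1, 0.3, 0.7],
], dtype=np.float32)

k = 2  # neighborhood size, as in --neighborhood

# Mean similarity over each sentence's k nearest neighbors, in both directions
fwd_mean = np.sort(sim, axis=1)[:, -k:].mean(axis=1)  # per source sentence
bwd_mean = np.sort(sim, axis=0)[-k:, :].mean(axis=0)  # per target sentence

# Ratio margin, as in mine.py's score(): sim / ((fwd_mean + bwd_mean) / 2)
margin = sim / ((fwd_mean[:, None] + bwd_mean[None, :]) / 2)

# Each source sentence is paired with its highest-margin target
best = margin.argmax(axis=1)
print(best.tolist())  # -> [0, 1, 2]
```

The real script does the same computation shard by shard over faiss k-NN results rather than on a dense matrix.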
fairseq-0.10.2/examples/criss/mining/mine_example.sh
ADDED
@@ -0,0 +1,103 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
source_lang=kk_KZ
target_lang=en_XX
MODEL=criss_checkpoints/criss.2nd.pt
SPM=criss_checkpoints/sentence.bpe.model
SPLIT=test
LANG_DICT=criss_checkpoints/lang_dict.txt
SPM_ENCODE=flores/scripts/spm_encode.py
SAVE_ENCODER=save_encoder.py
ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL
DICT=criss_checkpoints/dict.txt
THRESHOLD=1.02
MIN_COUNT=500

DATA_DIR=data_tmp
SAVE_DIR=mining/${source_lang}_${target_lang}_mined
ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba

mkdir -p $ENCODER_SAVE_DIR/${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${source_lang}
mkdir -p $SAVE_DIR

## Save encoder outputs

# Save encoder outputs for source sentences
python $SAVE_ENCODER \
  ${INPUT_DIR} \
  --path ${MODEL} \
  --task translation_multi_simple_epoch \
  --lang-pairs ${source_lang}-${target_lang} \
  --lang-dict ${LANG_DICT} \
  --gen-subset ${SPLIT} \
  --bpe 'sentencepiece' \
  -s ${source_lang} -t ${target_lang} \
  --sentencepiece-model ${SPM} \
  --remove-bpe 'sentencepiece' \
  --beam 1 \
  --lang-tok-style mbart \
  --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}

## Save encoder outputs for target sentences
python $SAVE_ENCODER \
  ${INPUT_DIR} \
  --path ${MODEL} \
  --lang-pairs ${source_lang}-${target_lang} \
  --lang-dict ${LANG_DICT} \
  --task translation_multi_simple_epoch \
  --gen-subset ${SPLIT} \
  --bpe 'sentencepiece' \
  -t ${source_lang} -s ${target_lang} \
  --sentencepiece-model ${SPM} \
  --remove-bpe 'sentencepiece' \
  --beam 1 \
  --lang-tok-style mbart \
  --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}

## Mining
python mining/mine.py \
  --src-lang ${source_lang} \
  --tgt-lang ${target_lang} \
  --dim 1024 \
  --mem 10 \
  --neighborhood 4 \
  --src-dir ${ENCODER_SAVE_DIR}/${source_lang} \
  --tgt-dir ${ENCODER_SAVE_DIR}/${target_lang} \
  --output $SAVE_DIR \
  --threshold ${THRESHOLD} \
  --min-count ${MIN_COUNT} \
  --valid-size 100 \
  --dict-path ${DICT} \
  --spm-path ${SPM}


## Process and binarize mined data
python $SPM_ENCODE \
  --model ${SPM} \
  --output_format=piece \
  --inputs mining/${source_lang}_${target_lang}_mined/train.${source_lang} mining/${source_lang}_${target_lang}_mined/train.${target_lang} \
  --outputs mining/${source_lang}_${target_lang}_mined/train.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/train.bpe.${target_lang}

python $SPM_ENCODE \
  --model ${SPM} \
  --output_format=piece \
  --inputs mining/${source_lang}_${target_lang}_mined/valid.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.${target_lang} \
  --outputs mining/${source_lang}_${target_lang}_mined/valid.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.bpe.${target_lang}


fairseq-preprocess \
  --source-lang ${source_lang} \
  --target-lang ${target_lang} \
  --trainpref mining/${source_lang}_${target_lang}_mined/train.bpe \
  --validpref mining/${source_lang}_${target_lang}_mined/valid.bpe \
  --destdir mining/${source_lang}_${target_lang}_mined \
  --srcdict ${DICT} \
  --joined-dictionary \
  --workers 8
fairseq-0.10.2/examples/criss/save_encoder.py
ADDED
@@ -0,0 +1,213 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""

import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel


def get_avg_pool(
    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
    model = EnsembleModel(models)

    # model.forward normally channels prev_output_tokens into the decoder
    # separately, but SequenceGenerator directly calls model.encoder
    encoder_input = {
        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
    }

    # compute the encoder output for each beam
    encoder_outs = model.forward_encoder(encoder_input)
    np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
        np.float32
    )
    encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
    if has_langtok:
        # drop the language token position from both mask and outputs
        # (slice, not index, so the time axis is preserved)
        encoder_mask = encoder_mask[1:, :, :]
        np_encoder_outs = np_encoder_outs[1:, :, :]
    masked_encoder_outs = encoder_mask * np_encoder_outs
    avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
    return avg_pool


def main(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"

    args.beam = 1
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 12000
    print(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print("| loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    num_sentences = 0
    source_sentences = []
    shard_id = 0
    all_avg_pool = None
    encoder_has_langtok = (
        hasattr(task.args, "encoder_langtok")
        and task.args.encoder_langtok is not None
        and hasattr(task.args, "lang_tok_replacing_bos_eos")
        and not task.args.lang_tok_replacing_bos_eos
    )
    with progress_bar.build_progress_bar(args, itr) as t:
        for sample in t:
            if sample is None:
                print("Skipping None")
                continue
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            with torch.no_grad():
                avg_pool = get_avg_pool(
                    models,
                    sample,
                    prefix_tokens,
                    src_dict,
                    args.remove_bpe,
                    has_langtok=encoder_has_langtok,
                )
                if all_avg_pool is not None:
                    all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
                else:
                    all_avg_pool = avg_pool

            if not isinstance(sample["id"], list):
                sample_ids = sample["id"].tolist()
            else:
                sample_ids = sample["id"]
            for i, sample_id in enumerate(sample_ids):
                # Remove padding
                src_tokens = utils.strip_pad(
                    sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
                )

                # Either retrieve the original sentences or regenerate them from tokens.
                if align_dict is not None:
                    src_str = task.dataset(args.gen_subset).src.get_original_text(
                        sample_id
                    )
                else:
                    if src_dict is not None:
                        src_str = src_dict.string(src_tokens, args.remove_bpe)
                    else:
                        src_str = ""

                if not args.quiet:
                    if src_dict is not None:
                        print("S-{}\t{}".format(sample_id, src_str))

                source_sentences.append(f"{sample_id}\t{src_str}")

            num_sentences += sample["nsentences"]
            if all_avg_pool.shape[0] >= 1000000:
                with open(
                    f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
                    "w",
                ) as avg_pool_file:
                    all_avg_pool.tofile(avg_pool_file)
                with open(
                    f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
                    "w",
                ) as sentence_file:
                    sentence_file.writelines(f"{line}\n" for line in source_sentences)
                all_avg_pool = None
                source_sentences = []
                shard_id += 1

    if all_avg_pool is not None:
        with open(
            f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
        ) as avg_pool_file:
            all_avg_pool.tofile(avg_pool_file)
        with open(
            f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
        ) as sentence_file:
            sentence_file.writelines(f"{line}\n" for line in source_sentences)
    return None


def cli_main():
    parser = options.get_generation_parser()
    parser.add_argument(
        "--encoder-save-dir",
        default="",
        type=str,
        metavar="N",
        help="directory to save encoder outputs",
    )
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
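get_avg_pool in save_encoder.py mean-pools encoder states over the time axis while excluding padded positions via the inverted padding mask. The same masked average can be sketched with NumPy alone, using fairseq's time-major `(T, B, C)` layout and toy values:

```python
import numpy as np

T, B, C = 4, 2, 3  # time steps, batch size, channels
enc_out = np.ones((T, B, C), dtype=np.float32)
enc_out[2:, 0] = 5.0  # garbage values in sentence 0's padded positions

# padding mask: True where padded; sentence 0 has length 2, sentence 1 length 4
padding_mask = np.array([
    [False, False],
    [False, False],
    [True,  False],
    [True,  False],
])

mask = 1.0 - padding_mask.astype(np.float32)  # (T, B): 1 on real tokens
mask = np.expand_dims(mask, axis=2)           # (T, B, 1), broadcasts over C

# zero out padded states, then divide by each sentence's true length
avg_pool = (enc_out * mask).sum(axis=0) / mask.sum(axis=0)  # (B, C)
print(avg_pool)  # both rows are all 1.0: the 5.0 padding garbage is excluded
```

A naive `enc_out.mean(axis=0)` would give 3.0 for sentence 0, since it averages the padded garbage in; the mask is what keeps the pooled embedding faithful to the real tokens.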
fairseq-0.10.2/examples/criss/sentence_retrieval/encoder_analysis.py
ADDED
@@ -0,0 +1,92 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob

import numpy as np


DIM = 1024


def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
    target_ids = [tid for tid in target_embs]
    source_mat = np.stack(source_embs.values(), axis=0)
    normalized_source_mat = source_mat / np.linalg.norm(
        source_mat, axis=1, keepdims=True
    )
    target_mat = np.stack(target_embs.values(), axis=0)
    normalized_target_mat = target_mat / np.linalg.norm(
        target_mat, axis=1, keepdims=True
    )
    sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
    if return_sim_mat:
        return sim_mat
    neighbors_map = {}
    for i, sentence_id in enumerate(source_embs):
        idx = np.argsort(sim_mat[i, :])[::-1][:k]
        neighbors_map[sentence_id] = [target_ids[tid] for tid in idx]
    return neighbors_map


def load_embeddings(directory, LANGS):
    sentence_embeddings = {}
    sentence_texts = {}
    for lang in LANGS:
        sentence_embeddings[lang] = {}
        sentence_texts[lang] = {}
        lang_dir = f"{directory}/{lang}"
        embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
        for embed_file in embedding_files:
            shard_id = embed_file.split(".")[-1]
            embeddings = np.fromfile(embed_file, dtype=np.float32)
            num_rows = embeddings.shape[0] // DIM
            embeddings = embeddings.reshape((num_rows, DIM))

            with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
                for idx, line in enumerate(sentence_file):
                    sentence_id, sentence = line.strip().split("\t")
                    sentence_texts[lang][sentence_id] = sentence
                    sentence_embeddings[lang][sentence_id] = embeddings[idx, :]

    return sentence_embeddings, sentence_texts


def compute_accuracy(directory, LANGS):
    sentence_embeddings, sentence_texts = load_embeddings(directory, LANGS)

    top_1_accuracy = {}

    top1_str = " ".join(LANGS) + "\n"
    for source_lang in LANGS:
        top_1_accuracy[source_lang] = {}
        top1_str += f"{source_lang} "
        for target_lang in LANGS:
            top1 = 0
            top5 = 0
            neighbors_map = compute_dist(
                sentence_embeddings[source_lang], sentence_embeddings[target_lang]
            )
            for sentence_id, neighbors in neighbors_map.items():
                if sentence_id == neighbors[0]:
                    top1 += 1
                if sentence_id in neighbors[:5]:
                    top5 += 1
            n = len(sentence_embeddings[target_lang])
            top1_str += f"{top1/n} "
        top1_str += "\n"

    print(top1_str)
    print(top1_str, file=open(f"{directory}/accuracy", "w"))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
    parser.add_argument("directory", help="Source language corpus")
    parser.add_argument("--langs", help="List of langs")
    args = parser.parse_args()
    langs = args.langs.split(",")
    compute_accuracy(args.directory, langs)
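encoder_analysis.py counts a retrieval hit when a sentence's own ID comes back as its cosine nearest neighbor in the other language. The core of `compute_dist` plus the top-1 count can be sketched on hypothetical 2-D embeddings:

```python
import numpy as np

# Hypothetical sentence embeddings, keyed by sentence ID
src = {"s1": np.array([1.0, 0.0]), "s2": np.array([0.0, 1.0])}
tgt = {"s1": np.array([2.0, 0.1]), "s2": np.array([0.1, 3.0])}

# L2-normalize both sets so the dot product is cosine similarity
src_mat = np.stack(list(src.values()))
tgt_mat = np.stack(list(tgt.values()))
src_mat /= np.linalg.norm(src_mat, axis=1, keepdims=True)
tgt_mat /= np.linalg.norm(tgt_mat, axis=1, keepdims=True)

sim = src_mat @ tgt_mat.T  # (num_src, num_tgt) cosine similarities

# Top-1 retrieval: a hit when the nearest target shares the source's ID
tgt_ids = list(tgt)
top1 = sum(
    1 for i, sid in enumerate(src) if tgt_ids[sim[i].argmax()] == sid
)
print(top1 / len(src))  # -> 1.0 for these toy vectors
```

The script then reports this ratio for every source/target language pair in a small text matrix.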
fairseq-0.10.2/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh
ADDED
@@ -0,0 +1,59 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
source_lang=kk_KZ
target_lang=en_XX
MODEL=criss_checkpoints/criss.3rd.pt
SPM=criss_checkpoints/sentence.bpe.model
SPLIT=test
LANG_DICT=criss_checkpoints/lang_dict.txt
ENCODER_ANALYSIS=sentence_retrieval/encoder_analysis.py
SAVE_ENCODER=save_encoder.py
ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL


DATA_DIR=data_tmp
INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba
ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${source_lang}

# Save encoder outputs for source sentences
python $SAVE_ENCODER \
  ${INPUT_DIR} \
  --path ${MODEL} \
  --task translation_multi_simple_epoch \
  --lang-dict ${LANG_DICT} \
  --gen-subset ${SPLIT} \
  --bpe 'sentencepiece' \
  --lang-pairs ${source_lang}-${target_lang} \
  -s ${source_lang} -t ${target_lang} \
  --sentencepiece-model ${SPM} \
  --remove-bpe 'sentencepiece' \
  --beam 1 \
  --lang-tok-style mbart \
  --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}

# Save encoder outputs for target sentences
python $SAVE_ENCODER \
  ${INPUT_DIR} \
  --path ${MODEL} \
  --lang-dict ${LANG_DICT} \
  --task translation_multi_simple_epoch \
  --gen-subset ${SPLIT} \
  --bpe 'sentencepiece' \
  --lang-pairs ${target_lang}-${source_lang} \
  -t ${source_lang} -s ${target_lang} \
  --sentencepiece-model ${SPM} \
  --remove-bpe 'sentencepiece' \
  --beam 1 \
  --lang-tok-style mbart \
  --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}

# Analyze sentence retrieval accuracy
python $ENCODER_ANALYSIS --langs "${source_lang},${target_lang}" ${ENCODER_SAVE_DIR}
fairseq-0.10.2/examples/criss/unsupervised_mt/eval.sh
ADDED
@@ -0,0 +1,37 @@
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
#
|
| 8 |
+
SRC=si_LK
|
| 9 |
+
TGT=en_XX
|
| 10 |
+
MODEL=criss_checkpoints/criss.3rd.pt
|
| 11 |
+
|
| 12 |
+
MULTIBLEU=mosesdecoder/scripts/generic/multi-bleu.perl
|
| 13 |
+
MOSES=mosesdecoder
|
| 14 |
+
REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
|
| 15 |
+
NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
|
| 16 |
+
REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
|
| 17 |
+
TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl
|
| 18 |
+
GEN_TMP_DIR=gen_tmp
|
| 19 |
+
LANG_DICT=criss_checkpoints/lang_dict.txt
|
| 20 |
+
|
| 21 |
+
if [ ! -d "mosesdecoder" ]; then
|
| 22 |
+
git clone https://github.com/moses-smt/mosesdecoder
|
| 23 |
+
fi
|
| 24 |
+
mkdir -p $GEN_TMP_DIR
|
| 25 |
+
fairseq-generate data_tmp/${SRC}-${TGT}-flores \
|
| 26 |
+
--task translation_multi_simple_epoch \
|
| 27 |
+
--max-tokens 2000 \
|
| 28 |
+
--path ${MODEL} \
|
| 29 |
+
--skip-invalid-size-inputs-valid-test \
|
| 30 |
+
--beam 5 --lenpen 1.0 --gen-subset test \
|
| 31 |
+
--remove-bpe=sentencepiece \
|
| 32 |
+
--source-lang ${SRC} --target-lang ${TGT} \
|
| 33 |
+
--decoder-langtok --lang-pairs 'en_XX-ar_AR,en_XX-de_DE,en_XX-es_XX,en_XX-fr_XX,en_XX-hi_IN,en_XX-it_IT,en_XX-ja_XX,en_XX-ko_KR,en_XX-nl_XX,en_XX-ru_RU,en_XX-zh_CN,en_XX-tr_TR,en_XX-vi_VN,en_XX-ro_RO,en_XX-my_MM,en_XX-ne_NP,en_XX-si_LK,en_XX-cs_CZ,en_XX-lt_LT,en_XX-kk_KZ,en_XX-gu_IN,en_XX-fi_FI,en_XX-et_EE,en_XX-lv_LV,ar_AR-en_XX,cs_CZ-en_XX,de_DE-en_XX,es_XX-en_XX,et_EE-en_XX,fi_FI-en_XX,fr_XX-en_XX,gu_IN-en_XX,hi_IN-en_XX,it_IT-en_XX,ja_XX-en_XX,kk_KZ-en_XX,ko_KR-en_XX,lt_LT-en_XX,lv_LV-en_XX,my_MM-en_XX,ne_NP-en_XX,nl_XX-en_XX,ro_RO-en_XX,ru_RU-en_XX,si_LK-en_XX,tr_TR-en_XX,vi_VN-en_XX,zh_CN-en_XX,ar_AR-es_XX,es_XX-ar_AR,ar_AR-hi_IN,hi_IN-ar_AR,ar_AR-zh_CN,zh_CN-ar_AR,cs_CZ-es_XX,es_XX-cs_CZ,cs_CZ-hi_IN,hi_IN-cs_CZ,cs_CZ-zh_CN,zh_CN-cs_CZ,de_DE-es_XX,es_XX-de_DE,de_DE-hi_IN,hi_IN-de_DE,de_DE-zh_CN,zh_CN-de_DE,es_XX-hi_IN,hi_IN-es_XX,es_XX-zh_CN,zh_CN-es_XX,et_EE-es_XX,es_XX-et_EE,et_EE-hi_IN,hi_IN-et_EE,et_EE-zh_CN,zh_CN-et_EE,fi_FI-es_XX,es_XX-fi_FI,fi_FI-hi_IN,hi_IN-fi_FI,fi_FI-zh_CN,zh_CN-fi_FI,fr_XX-es_XX,es_XX-fr_XX,fr_XX-hi_IN,hi_IN-fr_XX,fr_XX-zh_CN,zh_CN-fr_XX,gu_IN-es_XX,es_XX-gu_IN,gu_IN-hi_IN,hi_IN-gu_IN,gu_IN-zh_CN,zh_CN-gu_IN,hi_IN-zh_CN,zh_CN-hi_IN,it_IT-es_XX,es_XX-it_IT,it_IT-hi_IN,hi_IN-it_IT,it_IT-zh_CN,zh_CN-it_IT,ja_XX-es_XX,es_XX-ja_XX,ja_XX-hi_IN,hi_IN-ja_XX,ja_XX-zh_CN,zh_CN-ja_XX,kk_KZ-es_XX,es_XX-kk_KZ,kk_KZ-hi_IN,hi_IN-kk_KZ,kk_KZ-zh_CN,zh_CN-kk_KZ,ko_KR-es_XX,es_XX-ko_KR,ko_KR-hi_IN,hi_IN-ko_KR,ko_KR-zh_CN,zh_CN-ko_KR,lt_LT-es_XX,es_XX-lt_LT,lt_LT-hi_IN,hi_IN-lt_LT,lt_LT-zh_CN,zh_CN-lt_LT,lv_LV-es_XX,es_XX-lv_LV,lv_LV-hi_IN,hi_IN-lv_LV,lv_LV-zh_CN,zh_CN-lv_LV,my_MM-es_XX,es_XX-my_MM,my_MM-hi_IN,hi_IN-my_MM,my_MM-zh_CN,zh_CN-my_MM,ne_NP-es_XX,es_XX-ne_NP,ne_NP-hi_IN,hi_IN-ne_NP,ne_NP-zh_CN,zh_CN-ne_NP,nl_XX-es_XX,es_XX-nl_XX,nl_XX-hi_IN,hi_IN-nl_XX,nl_XX-zh_CN,zh_CN-nl_XX,ro_RO-es_XX,es_XX-ro_RO,ro_RO-hi_IN,hi_IN-ro_RO,ro_RO-zh_CN,zh_CN-ro_RO,ru_RU-es_XX,es_XX-ru_RU,ru_RU-hi_IN,hi_IN-ru_RU,ru_RU-zh_CN,zh_CN-ru_RU,si_LK-es_XX,es_XX-si_LK,
si_LK-hi_IN,hi_IN-si_LK,si_LK-zh_CN,zh_CN-si_LK,tr_TR-es_XX,es_XX-tr_TR,tr_TR-hi_IN,hi_IN-tr_TR,tr_TR-zh_CN,zh_CN-tr_TR,vi_VN-es_XX,es_XX-vi_VN,vi_VN-hi_IN,hi_IN-vi_VN,vi_VN-zh_CN,zh_CN-vi_VN' \
--lang-dict ${LANG_DICT} --lang-tok-style 'mbart' --sampling-method 'temperature' --sampling-temperature '1.0' > $GEN_TMP_DIR/${SRC}_${TGT}.gen

# T- lines carry the reference; H- lines carry the hypothesis (field 2 is the score)
cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^T-" | cut -f2 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.ref

cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^H-" | cut -f3 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.hyp

${MULTIBLEU} $GEN_TMP_DIR/${SRC}_${TGT}.ref < $GEN_TMP_DIR/${SRC}_${TGT}.hyp
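The grep/cut pipeline above depends on the layout of fairseq-generate output: `T-<id>` lines carry the reference in the second tab-separated field, and `H-<id>` lines carry the model score and then the hypothesis. As a sketch of the same extraction in Python (the sample lines are illustrative, not real model output):

```python
def parse_generate_output(lines):
    """Split fairseq-generate output lines into {sentence_id: reference}
    and {sentence_id: hypothesis} dicts, mirroring the grep/cut pipeline."""
    refs, hyps = {}, {}
    for line in lines:
        fields = line.rstrip("\n").split("\t")
        tag = fields[0]
        if tag.startswith("T-"):
            refs[int(tag[2:])] = fields[1]   # cut -f2
        elif tag.startswith("H-"):
            hyps[int(tag[2:])] = fields[2]   # cut -f3 (field 2 is the score)
    return refs, hyps

refs, hyps = parse_generate_output([
    "S-0\tGuten Morgen",
    "T-0\tGood morning",
    "H-0\t-0.31\tGood morning .",
])
```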
fairseq-0.10.2/examples/cross_lingual_language_model/README.md
ADDED
@@ -0,0 +1,77 @@
# Cross-Lingual Language Model Pre-training

Below are some details for training Cross-Lingual Language Models (XLM) - similar to the ones presented in [Lample & Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf) - in Fairseq. The current implementation only supports the Masked Language Model (MLM) from the paper above.

## Downloading and Tokenizing Monolingual Data

Pointers to the monolingual Wikipedia data used for training the XLM-style MLM model, as well as details on processing it (tokenization and BPE), can be found in the [XLM GitHub Repository](https://github.com/facebookresearch/XLM#download--preprocess-monolingual-data).

The code snippets in the later sections assume the following:
- Processed data is in the folder: monolingual_data/processed
- Each language has 3 files, for train, validation and test. For example, English has:
  train.en, valid.en, test.en
- We are training a model for 5 languages: Arabic (ar), German (de), English (en), Hindi (hi) and French (fr)
- The vocabulary file is monolingual_data/processed/vocab_mlm


## Fairseq Pre-processing and Binarization

Pre-process and binarize the data with the MaskedLMDictionary and the cross_lingual_lm task:

```bash
# Ensure the output directory exists
DATA_DIR=monolingual_data/fairseq_processed
mkdir -p "$DATA_DIR"

for lg in ar de en hi fr
do

  fairseq-preprocess \
  --task cross_lingual_lm \
  --srcdict monolingual_data/processed/vocab_mlm \
  --only-source \
  --trainpref monolingual_data/processed/train \
  --validpref monolingual_data/processed/valid \
  --testpref monolingual_data/processed/test \
  --destdir monolingual_data/fairseq_processed \
  --workers 20 \
  --source-lang $lg

  # Since we only have a source language, the output files have a "None"
  # in place of the target language. Rename them to drop it.
  for stage in train test valid
  do
    mv "$DATA_DIR/$stage.$lg-None.$lg.bin" "$DATA_DIR/$stage.$lg.bin"
    mv "$DATA_DIR/$stage.$lg-None.$lg.idx" "$DATA_DIR/$stage.$lg.idx"
  done

done
```

## Train a Cross-lingual Language Model similar to the XLM MLM model

Use the following command to train the model on 5 languages.

```
fairseq-train \
--task cross_lingual_lm monolingual_data/fairseq_processed \
--save-dir checkpoints/mlm \
--max-update 2400000 --save-interval 1 --no-epoch-checkpoints \
--arch xlm_base \
--optimizer adam --lr-scheduler reduce_lr_on_plateau \
--lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \
--dropout 0.1 \
--criterion legacy_masked_lm_loss \
--max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \
--dataset-impl lazy --seed 0 \
--masked-lm-only \
--monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \
--ddp-backend=no_c10d
```

Some notes:
- Using a tokens_per_sample greater than 256 can cause OOM (out-of-memory) issues. Since MLM packs streams of text together, this parameter usually doesn't need much tuning.
- The evaluation workflow for computing MLM perplexity on test data is in progress.
- Fine-tuning this model on a downstream task is not currently supported.
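For intuition about the MLM objective trained above, here is a minimal, dependency-free sketch of BERT/XLM-style token masking. The 15% selection rate and the 80/10/10 mask/random/keep split are the standard recipe from the literature, not code from this repository:

```python
import random

def mask_tokens(tokens, mask_token="<mask>", vocab=None, p=0.15, seed=0):
    """Randomly select a fraction p of positions; replace 80% of them with
    <mask>, 10% with a random vocabulary token, and leave 10% unchanged.
    Returns the corrupted stream and the positions the model must predict."""
    rng = random.Random(seed)
    vocab = vocab or list(tokens)
    out, targets = list(tokens), {}
    for i, tok in enumerate(tokens):
        if rng.random() < p:
            targets[i] = tok
            r = rng.random()
            if r < 0.8:
                out[i] = mask_token
            elif r < 0.9:
                out[i] = rng.choice(vocab)
            # else: keep the original token (the model still predicts it)
    return out, targets

tokens = "the cat sat on the mat and the dog ran".split()
corrupted, targets = mask_tokens(tokens, seed=1)
```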
fairseq-0.10.2/examples/latent_depth/latent_depth_src/loss/latent_depth.py
ADDED
@@ -0,0 +1,99 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import torch
from torch.nn.modules.loss import _Loss


class LatentLayersKLLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def forward(self, layer_samples, lang_idx, update_num, sample_size):
        prior = self.args.prior
        samples = layer_samples[lang_idx]
        eps = 1e-7
        if prior == "uniform":
            # uniform prior
            kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
        elif prior == "agged_posterior":
            # aggregated posterior
            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
            agged_q = torch.sum(y_t, dim=0)
            row_norm = agged_q.sum(-1)
            normed_agg_q = agged_q / row_norm
            kl_loss = (
                samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
            ).sum(-1)
        else:
            raise NotImplementedError("The specified prior is not implemented.")

        # normalize by the number of layers
        kl_loss /= layer_samples[0].size()[0]
        # anneal the KL weight up to args.sparsity_weight over args.anneal_updates
        kl_weight = min(
            self.args.sparsity_weight,
            (update_num - self.args.soft_update)
            * self.args.sparsity_weight
            / self.args.anneal_updates,
        )
        kl_loss *= kl_weight * sample_size
        return kl_loss


class LatentLayersSparsityLoss(_Loss):
    def __init__(self, args):
        super().__init__()
        self.args = args

    def is_valid(self, update_num):
        if self.args.target_layers <= 0:
            return False
        return update_num > (self.args.soft_update + self.args.anneal_updates)

    def forward(self, layer_samples_list, update_num, sample_size):
        batch_loss = 0
        share_loss = 0
        global_sparsity_loss = 0
        layer_samples = torch.stack(layer_samples_list, dim=0)
        if (
            self.args.target_layers > 0 or self.args.share_weight > 0
        ) and update_num > (self.args.soft_update + self.args.anneal_updates):
            # anneal sparsity weight
            if update_num < (self.args.anneal_updates + self.args.soft_update):
                weight_anneal = 0
            elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
                weight_anneal = (
                    (update_num - self.args.soft_update - self.args.anneal_updates)
                    * self.args.share_weight
                    / self.args.anneal_updates
                )
            else:
                weight_anneal = 1
            # compute layer utilization ratio across languages
            layer_utilization = torch.sum(layer_samples, dim=0)
            layer_utilization /= layer_samples.size()[0]
            if self.args.share_weight > 0:
                # encourage sharing across languages
                share_loss = sum(
                    -1.0 * v * math.log(v) for v in layer_utilization if v > 0
                )
                batch_loss += (
                    weight_anneal * self.args.share_weight * sample_size * share_loss
                )
            if self.args.target_layers > 0:
                # compute the expected number of selected layers
                expected_layers = sum(layer_utilization)
                # compute l2 loss wrt the target number of layers
                global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
                batch_loss += (
                    weight_anneal
                    * self.args.share_weight
                    * sample_size
                    * global_sparsity_loss
                )
        return batch_loss
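As a sanity check on the uniform-prior branch of `LatentLayersKLLoss`, here is a dependency-free sketch of the per-layer KL term it computes (plain Python standing in for torch tensors):

```python
import math

def kl_to_uniform(samples, eps=1e-7):
    """The uniform-prior term as implemented above:
    sum_i q_i * (log(q_i + eps) - log(0.5))."""
    return sum(q * (math.log(q + eps) - math.log(0.5)) for q in samples)

# When every layer-selection probability equals the prior (0.5), the term vanishes.
uniform_kl = kl_to_uniform([0.5, 0.5, 0.5, 0.5])
# Confident selections pay a cost relative to the uniform prior.
confident_kl = kl_to_uniform([0.99, 0.99])
```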
fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/__init__.py
ADDED
File without changes
fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py
ADDED
@@ -0,0 +1,59 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from fairseq.models import register_model, register_model_architecture
from fairseq.models.multilingual_transformer import MultilingualTransformerModel
from fairseq.models.transformer import (
    TransformerDecoder,
    TransformerEncoder,
    base_architecture,
)

from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder


@register_model("latent_multilingual_transformer")
class LatentMultilingualTransformerModel(MultilingualTransformerModel):
    """A variant of the standard multilingual Transformer model whose encoder
    and/or decoder supports latent depth, as in "Deep Transformer with Latent
    Depth" (https://arxiv.org/abs/2009.13102).
    """

    @classmethod
    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
        if is_encoder:
            if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
                return LatentTransformerEncoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerEncoder(args, lang_dict, embed_tokens)
        else:
            if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
                return LatentTransformerDecoder(
                    args, lang_dict, embed_tokens, num_logits=len(langs)
                )
            else:
                return TransformerDecoder(args, lang_dict, embed_tokens)


@register_model_architecture(
    "latent_multilingual_transformer", "latent_multilingual_transformer"
)
def latent_multilingual_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
    args.decoder_layers = getattr(args, "decoder_layers", 24)
    args.share_encoders = getattr(args, "share_encoders", True)
    args.share_decoders = getattr(args, "share_decoders", True)
    args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
    args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)

    base_architecture(args)
fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/latent_transformer.py
ADDED
@@ -0,0 +1,146 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Optional

import torch.nn as nn
from fairseq.models.fairseq_encoder import EncoderOut
from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
from torch import Tensor

from ..modules.latent_layers import LayerSelect


class LatentTransformerEncoder(TransformerEncoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerEncoder.
    """

    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
        self.num_logits = num_logits
        self.num_layers = args.encoder_layers
        super().__init__(args, dictionary, embed_tokens)
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList(
            [self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
        )

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_encoder_layer(self, args, idx=None):
        return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)

    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
        self.layer_select.sample(self.lang_idx)
        return super().forward(src_tokens, src_lengths, return_all_hiddens)


class LatentTransformerEncoderLayer(TransformerEncoderLayer):
    """Encoder layer with each (non-residual) block weighted by samples from a
    Bernoulli or Gumbel-Sigmoid distribution.

    Args:
        args (argparse.Namespace): parsed command-line arguments from standard
            TransformerEncoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
    """

    def __init__(self, args, idx, layer_select=None):
        super().__init__(args)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)


class LatentTransformerDecoder(TransformerDecoder):
    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
    TransformerDecoder.
    """

    def __init__(
        self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
    ):
        self.num_logits = num_logits
        self.num_layers = args.decoder_layers
        super().__init__(
            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
        )
        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
        self.lang_idx = None
        self.layers = nn.ModuleList(
            [
                self._build_decoder_layer(args, no_encoder_attn, idx)
                for idx in range(args.decoder_layers)
            ]
        )

    def set_lang_idx(self, lang_idx):
        self.lang_idx = lang_idx

    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
        return LatentTransformerDecoderLayer(
            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
        )

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = None,
        alignment_heads: Optional[int] = None,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        self.layer_select.sample(self.lang_idx)
        return super().forward(
            prev_output_tokens=prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            features_only=features_only,
            alignment_layer=alignment_layer,
            src_lengths=src_lengths,
            return_all_hiddens=return_all_hiddens,
        )


class LatentTransformerDecoderLayer(TransformerDecoderLayer):
    """Decoder layer with each (non-residual) block weighted by samples from a
    Bernoulli or Gumbel-Sigmoid distribution.

    Args:
        args (argparse.Namespace): parsed command-line arguments from standard
            TransformerDecoderLayer.
        idx (int): layer index (used to retrieve samples).
        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
            parameters and sampling method.
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        args,
        idx,
        layer_select=None,
        no_encoder_attn=False,
        add_bias_kv=False,
        add_zero_attn=False,
    ):
        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
        self.idx = idx
        self.layer_select = layer_select

    def residual_connection(self, x, residual):
        return residual + x * self.layer_select(self.idx)
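The `residual_connection` override is the core of the latent-depth mechanism: each layer's output is multiplied by a (possibly hard) gate before being added back to the residual stream, so a gate of 0 skips the layer entirely. A plain-Python illustration with toy numbers (not repository code):

```python
def gated_residual(residual, layer_out, gate):
    """residual + gate * layer_out: gate=1.0 keeps the layer's contribution,
    gate=0.0 skips it, leaving the residual stream untouched."""
    return [r + gate * y for r, y in zip(residual, layer_out)]

x = [1.0, 2.0]
layer_out = [0.5, -0.5]
kept = gated_residual(x, layer_out, 1.0)     # layer applied
skipped = gated_residual(x, layer_out, 0.0)  # layer skipped: identity
```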
fairseq-0.10.2/examples/latent_depth/latent_depth_src/modules/__init__.py
ADDED
File without changes
fairseq-0.10.2/examples/latent_depth/latent_depth_src/modules/latent_layers.py
ADDED
@@ -0,0 +1,86 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn


class LayerSelect(nn.Module):
    """Compute samples (from a Gumbel-Sigmoid distribution) which are used as
    either (soft) weighting or (hard) selection of residual connections.
    https://arxiv.org/abs/2009.13102
    """

    def __init__(self, num_layers, num_logits, args):
        super(LayerSelect, self).__init__()
        self.args = args
        self.layer_logits = torch.nn.Parameter(
            torch.Tensor(num_logits, num_layers),
            requires_grad=True,
        )
        self.hard_select = not (hasattr(args, "soft_select") and args.soft_select)
        self.tau = getattr(args, "sampling_tau", 5)
        self.detach_grad = False
        self.layer_samples = [None] * num_logits

    @staticmethod
    def add_args(parser):
        parser.add_argument(
            "--soft-select",
            action="store_true",
            help="use soft samples in training and inference",
        )
        parser.add_argument("--sampling-tau", type=float, help="sampling temperature")

    def sample(self, logit_idx):
        """To leverage the efficiency of distributed training, samples for all
        layers are computed at once for each logit_idx. Logits are parameters
        learnt independently of each other.

        Args:
            logit_idx: The index of logit parameters used for sampling.
        """
        assert logit_idx is not None
        self.samples = self._gumbel_sigmoid(
            self.layer_logits[logit_idx, :].detach()
            if self.detach_grad
            else self.layer_logits[logit_idx, :],
            dim=-1,
            tau=self.tau,
            hard=self.hard_select,
        )
        self.layer_samples[logit_idx] = self.samples

    def forward(self, i):
        sample = self.samples[i]
        return sample

    def _gumbel_sigmoid(
        self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
    ):
        # ~Gumbel(0,1)
        gumbels1 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        gumbels2 = (
            -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
            .exponential_()
            .log()
        )
        # Difference of two Gumbels because we apply a sigmoid
        gumbels1 = (logits + gumbels1 - gumbels2) / tau
        y_soft = gumbels1.sigmoid()
        if hard:
            # Straight through.
            y_hard = torch.zeros_like(
                logits, memory_format=torch.legacy_contiguous_format
            ).masked_fill(y_soft > threshold, 1.0)
            ret = y_hard - y_soft.detach() + y_soft
        else:
            # Reparametrization trick.
            ret = y_soft
        return ret
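`_gumbel_sigmoid` draws two Gumbel variates and pushes their difference (plus the logit) through a temperature-scaled sigmoid; with `hard=True` it thresholds at 0.5 while keeping the soft value for the straight-through gradient. A dependency-free sketch of the sampling step only (forward pass, no straight-through machinery):

```python
import math
import random

def gumbel_sigmoid_sample(logit, tau=5.0, hard=True, rng=random):
    """One forward-pass sample: sigmoid of (logit + Gumbel noise difference) / tau,
    optionally thresholded at 0.5 for hard selection."""
    g1 = -math.log(rng.expovariate(1.0))  # Gumbel(0,1) via -log(Exponential(1))
    g2 = -math.log(rng.expovariate(1.0))
    y_soft = 1.0 / (1.0 + math.exp(-(logit + g1 - g2) / tau))
    if hard:
        return 1.0 if y_soft > 0.5 else 0.0
    return y_soft

rng = random.Random(0)
# A positive logit biases the hard samples toward keeping the layer.
hard_samples = [gumbel_sigmoid_sample(2.0, rng=rng) for _ in range(200)]
soft_sample = gumbel_sigmoid_sample(0.0, hard=False, rng=rng)
```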
fairseq-0.10.2/examples/m2m_100/README.md
ADDED
@@ -0,0 +1,209 @@
# Beyond English-Centric Multilingual Machine Translation

## Introduction
In this work, we create a true Many-to-Many multilingual translation model that can translate directly between any pair of 100 languages. Our focus on non-English-Centric models brings gains of more than 10 BLEU when directly translating between non-English directions, while performing competitively with the best single systems of WMT.

If you are new to using fairseq, read the following walkthrough. Otherwise, skip to the sections below.

0. **Generation Data**

To download the generation data, follow the commands below. Note that all datasets need to be detokenized *before* applying SPM in the data preprocessing step. If you use these evaluation datasets, please cite their associated papers.
```bash
# WMT - use sacrebleu, example here:
sacrebleu -t wmt14 -l fr-en --echo src > wmt.test.fr-en.fr
sacrebleu -t wmt14 -l fr-en --echo ref > wmt.test.fr-en.en

# WAT
wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2019.my-en.zip
unzip wat2019.my-en.zip

# FLORES
# download from: https://github.com/facebookresearch/flores

# TED - need to detokenize with Moses!
# from: https://github.com/neulab/word-embeddings-for-nmt
wget http://phontron.com/data/ted_talks.tar.gz

# Autshumato
# request to download: https://repo.sadilar.org/handle/20.500.12185/397

# Tatoeba Challenge
# available here: https://github.com/Helsinki-NLP/Tatoeba-Challenge
```

1. **Training Data**

To produce the training data, we use a combination of [CCMatrix](https://arxiv.org/abs/1911.04944) and [CCAligned](https://arxiv.org/abs/1911.06154). Check out the instructions [here](https://github.com/facebookresearch/LASER/tree/master/tasks/CCMatrix) to download the raw data.

2. **Preprocess Data**

After downloading the raw data, you will need to postprocess the data, then apply SPM, then binarize. Note that it is very important that you run the postprocessing script, because it removes any instance of the evaluation data from the mined training data.

```bash
# preprocess data

# remove sentences with more than 50% punctuation
python /path/to/fairseq/examples/m2m_100/process_data/remove_too_much_punc.py

# deduplicate training data
paste /path/to/datadir/train.$src /path/to/datadir/train.$tgt | awk '!x[$0]++' > /path/to/datadir/train.dedup
echo "keeping $(wc -l /path/to/datadir/train.dedup) bitext out of $(wc -l /path/to/datadir/train.$src)"
cut -f1 /path/to/datadir/train.dedup > /path/to/datadir/train.$src
cut -f2 /path/to/datadir/train.dedup > /path/to/datadir/train.$tgt

# remove all instances of evaluation data from the training data
python /path/to/fairseq/examples/m2m_100/process_data/dedup_data.py

# frequency cleaning
wget https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz
tar -xvzf histograms.tar.gz
python /path/to/fairseq/examples/m2m_100/process_data/clean_histogram.py --src $src --tgt $tgt --src-file /path/to/source/file --tgt-file /path/to/output/file --src-output-file source_output.$src --tgt-output-file target_output.$tgt --histograms /path/to/histograms

# apply SPM
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
python /path/to/fairseq/scripts/spm_encode.py \
    --model spm.128k.model \
    --output_format=piece \
    --inputs=/path/to/input/file/here \
    --outputs=/path/to/output/file/here

# length ratio cleaning
perl mosesdecoder/scripts/training/clean-corpus-n.perl --ratio 3 /path/to/training/data/train.spm.$src-$tgt $src $tgt /path/to/output/directory/train.spm.$src-$tgt 1 250

# binarize data
wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
fairseq-preprocess \
    --source-lang $src --target-lang $tgt \
    --testpref spm.$src.$tgt \
    --thresholdsrc 0 --thresholdtgt 0 \
    --destdir data_bin \
    --srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
```
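The `paste ... | awk '!x[$0]++'` step deduplicates the bitext by keeping only the first occurrence of each source/target pair. An equivalent, order-preserving Python sketch (the line lists stand in for the paired training files):

```python
def dedup_bitext(src_lines, tgt_lines):
    """Keep the first occurrence of each (source, target) pair, preserving
    the original order -- the same effect as awk '!x[$0]++' on pasted lines."""
    seen = set()
    kept = []
    for pair in zip(src_lines, tgt_lines):
        if pair not in seen:
            seen.add(pair)
            kept.append(pair)
    return kept

pairs = dedup_bitext(["a", "b", "a"], ["x", "y", "x"])
```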
| 82 |
+
|
| 83 |
+
3. **Training Scripts**
|
| 84 |
+
|
| 85 |
+
To reproduce the training of our models, we train with fairseq-py's multilingual translation [task](https://github.com/pytorch/fairseq/tree/master/examples/multilingual). If you are interested in model parallel training, also check out [fairscale](https://github.com/facebookresearch/fairscale).

4. **Generation**

To generate from our models, follow the commands in the generation section below.


If you use any of the resources listed here, please cite:
```bibtex
@article{fan2020beyond,
    title={Beyond English-Centric Multilingual Machine Translation},
    author={Fan, Angela and Bhosale, Shruti and Schwenk, Holger and Ma, Zhiyi and El-Kishky, Ahmed and Goyal, Siddharth and Baines, Mandeep and Celebi, Onur and Wenzek, Guillaume and Chaudhary, Vishrav and Goyal, Naman and Birch, Tom and Liptchinsky, Vitaliy and Edunov, Sergey and Grave, Edouard and Auli, Michael and Joulin, Armand},
    journal={arXiv preprint},
    year={2020}
}

@article{schwenk2019ccmatrix,
    title={CCMatrix: Mining billions of high-quality parallel sentences on the web},
    author={Schwenk, Holger and Wenzek, Guillaume and Edunov, Sergey and Grave, Edouard and Joulin, Armand},
    journal={arXiv preprint arXiv:1911.04944},
    year={2019}
}

@article{el2019massive,
    title={A Massive Collection of Cross-Lingual Web-Document Pairs},
    author={El-Kishky, Ahmed and Chaudhary, Vishrav and Guzman, Francisco and Koehn, Philipp},
    journal={arXiv preprint arXiv:1911.06154},
    year={2019}
}
```


## Trained Models

Looking for other trained models? Check back soon.

Model | Description | Download
---|---|---
`12b_last_checkpoint` | 12B parameter model trained on many-to-many training data for 100 languages | [12b_last_checkpoint](https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt)


## SentencePiece Model

```bash
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
```

## Generation with M2M-100

### Encode using our SentencePiece Model

Note: Install SentencePiece from [here](https://github.com/google/sentencepiece)

```bash
fairseq=/path/to/fairseq
cd $fairseq
sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de
sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr
wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
for lang in de fr ; do
    python scripts/spm_encode.py \
        --model spm.128k.model \
        --output_format=piece \
        --inputs=raw_input.de-fr.${lang} \
        --outputs=spm.de-fr.${lang}
done
```

### Binarization

```bash
wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
fairseq-preprocess \
    --source-lang de --target-lang fr \
    --testpref spm.de-fr \
    --thresholdsrc 0 --thresholdtgt 0 \
    --destdir data_bin \
    --srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
```

### Generation on a V100 GPU

```bash
wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt
wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs.txt
wget https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt
fairseq-generate \
    data_bin \
    --batch-size 1 \
    --path 12b_last_checkpoint.pt \
    --fixed-dictionary model_dict.128k.txt \
    -s de -t fr \
    --remove-bpe 'sentencepiece' \
    --beam 5 \
    --task translation_multi_simple_epoch \
    --lang-pairs language_pairs.txt \
    --decoder-langtok --encoder-langtok src \
    --gen-subset test \
    --fp16 \
    --dataset-impl mmap \
    --distributed-world-size 1 --distributed-no-spawn \
    --pipeline-model-parallel \
    --pipeline-chunks 1 \
    --pipeline-encoder-balance '[26]' \
    --pipeline-encoder-devices '[0]' \
    --pipeline-decoder-balance '[1,24,1]' \
    --pipeline-decoder-devices '[0,1,0]' > gen_out
```

## Evaluation with M2M-100

### Tokenization

Note: Refer to tokenizers/README.md for more details on tokenization.

```bash
cd ${fairseq}/examples/m2m_100
cat ${fairseq}/gen_out | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh fr > hyp
cat ${fairseq}/raw_input.de-fr.fr | sh tok.sh fr > ref
```

### BLEU

```bash
sacrebleu -tok 'none' ref < hyp
```
fairseq-0.10.2/examples/m2m_100/install_dependecies.sh
#!/usr/bin/env bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


CWD=`pwd`
INSTALL_PATH=$CWD/tokenizers/thirdparty

MOSES=$INSTALL_PATH/mosesdecoder
if [ ! -d $MOSES ]; then
    echo 'Cloning Moses github repository (for tokenization scripts)...'
    git clone https://github.com/moses-smt/mosesdecoder.git $MOSES
    cd $MOSES
    # To deal with differences in handling ' vs "
    git checkout 03578921cc1a03402
    cd -
fi

WMT16_SCRIPTS=$INSTALL_PATH/wmt16-scripts
if [ ! -d $WMT16_SCRIPTS ]; then
    echo 'Cloning Romanian tokenization scripts'
    git clone https://github.com/rsennrich/wmt16-scripts.git $WMT16_SCRIPTS
fi

KYTEA=$INSTALL_PATH/kytea
if [ ! -f $KYTEA/bin/kytea ]; then
    git clone https://github.com/neubig/kytea.git $KYTEA
    cd $KYTEA
    autoreconf -i
    ./configure --prefix=`pwd`
    make
    make install
    cd ..
fi

export MECAB=$INSTALL_PATH/mecab-0.996-ko-0.9.2
if [ ! -f $MECAB/bin/mecab ]; then
    cd $INSTALL_PATH
    curl -LO https://bitbucket.org/eunjeon/mecab-ko/downloads/mecab-0.996-ko-0.9.2.tar.gz
    tar zxfv mecab-0.996-ko-0.9.2.tar.gz
    cd mecab-0.996-ko-0.9.2/
    ./configure --prefix=`pwd`
    make
    make install

    cd ..
    curl -LO https://bitbucket.org/eunjeon/mecab-ko-dic/downloads/mecab-ko-dic-2.1.1-20180720.tar.gz
    tar zxfv mecab-ko-dic-2.1.1-20180720.tar.gz
    cd mecab-ko-dic-2.1.1-20180720/
    ./autogen.sh
    ./configure --prefix=`pwd` --with-dicdir=$MECAB/lib/mecab/dic/mecab-ko-dic --with-mecab-config=$MECAB/bin/mecab-config
    make
    sh -c 'echo "dicdir=$MECAB/lib/mecab/dic/mecab-ko-dic" > $MECAB/etc/mecabrc'
    make install
    cd $CWD
fi

INDIC_RESOURCES_PATH=$INSTALL_PATH/indic_nlp_resources
if [ ! -d $INDIC_RESOURCES_PATH ]; then
    echo 'Cloning indic_nlp_resources'
    git clone https://github.com/anoopkunchukuttan/indic_nlp_resources.git $INDIC_RESOURCES_PATH
fi


if [ ! -f $INSTALL_PATH/seg_my.py ]; then
    cd $INSTALL_PATH
    wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2020.my-en.zip
    unzip wat2020.my-en.zip
    # switch to python3
    cat wat2020.my-en/myseg.py | sed 's/^sys.std/###sys.std/g' | sed 's/### sys/sys/g' | sed 's/unichr/chr/g' > seg_my.py
    cd $CWD
fi


pip install pythainlp sacrebleu indic-nlp-library
fairseq-0.10.2/examples/m2m_100/process_data/remove_too_much_punc.py
import gzip
import argparse
from string import punctuation


def len_no_punc(s, punc):
    # counts the punctuation characters in s
    return len([ch for ch in s if ch in punc])


def filter_overpunc(len_npunc, len_sen):
    # keep only sentences whose punctuation is under 50% of the characters
    return len_npunc < 0.5 * len_sen


def main(args):
    punc = punctuation + "—|–"
    print('Processing file {}'.format(args.input))
    with gzip.open(args.input, 'rt', encoding=args.encoding) as tsv:
        with open(args.bitext + '.' + args.src_lang, 'wt', encoding=args.encoding) as fsrc:
            with open(args.bitext + '.' + args.tgt_lang, 'wt', encoding=args.encoding) as ftgt:
                line = tsv.readline()
                while line:
                    fields = line.split('\t')

                    src, tgt = fields[1], fields[2]

                    nchar_npunc_src = len_no_punc(src, punc)
                    nchar_npunc_tgt = len_no_punc(tgt, punc)

                    if filter_overpunc(nchar_npunc_src, len(src)) and filter_overpunc(nchar_npunc_tgt, len(tgt)):
                        fsrc.write(src.strip() + '\n')
                        ftgt.write(tgt.strip() + '\n')
                    line = tsv.readline()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, type=str)
    parser.add_argument('--encoding', default='utf-8', help='character encoding for input/output')
    parser.add_argument('--bitext', type=str, required=True, help='output bitext file prefix')
    parser.add_argument('--src-lang', type=str, required=True, help='Source language')
    parser.add_argument('--tgt-lang', type=str, required=True, help='Target language')
    main(parser.parse_args())
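The filtering rule above can be exercised on its own; this standalone snippet repeats the two helper functions to show which sentences would be kept (the sample sentences are made up):

```python
from string import punctuation

# Same helpers as in remove_too_much_punc.py above.
def len_no_punc(s, punc):
    # counts the punctuation characters in s
    return len([ch for ch in s if ch in punc])

def filter_overpunc(len_npunc, len_sen):
    # keep the sentence only if punctuation is under 50% of its characters
    return len_npunc < 0.5 * len_sen

punc = punctuation + "—|–"
ok = "A normal sentence, with a little punctuation."
bad = "!!! ??? ... !!! ??? ... !!!"

print(filter_overpunc(len_no_punc(ok, punc), len(ok)))    # True: kept
print(filter_overpunc(len_no_punc(bad, punc), len(bad)))  # False: dropped
```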
fairseq-0.10.2/examples/m2m_100/tok.sh
#!/usr/bin/env bash
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

set -e

TOKENIZERS_SCRIPTS=tokenizers
INSTALL_PATH=$TOKENIZERS_SCRIPTS/thirdparty

N_THREADS=8

lg=$1

MOSES=$INSTALL_PATH/mosesdecoder
REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl

# special tokenization for Romanian
WMT16_SCRIPTS=$INSTALL_PATH/wmt16-scripts

NORMALIZE_ROMANIAN=$WMT16_SCRIPTS/preprocess/normalise-romanian.py
REMOVE_DIACRITICS=$WMT16_SCRIPTS/preprocess/remove-diacritics.py

# Burmese
MY_SEGMENT=$INSTALL_PATH/seg_my.py

# Arabic
AR_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenizer_ar.sh

# Korean
KO_SEGMENT=$TOKENIZERS_SCRIPTS/seg_ko.sh

# Japanese
JA_SEGMENT=$TOKENIZERS_SCRIPTS/seg_ja.sh

# Indic
IN_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_indic.py
INDIC_RESOURCES_PATH=$INSTALL_PATH/indic_nlp_resources

# Thai
THAI_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_thai.py

# Chinese
CHINESE_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_zh.py

# Chinese
if [ "$lg" = "zh" ]; then
    cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | python $CHINESE_TOKENIZER
# Thai
elif [ "$lg" = "th" ]; then
    cat - | python $THAI_TOKENIZER
# Japanese
elif [ "$lg" = "ja" ]; then
    cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | ${JA_SEGMENT}
# Korean
elif [ "$lg" = "ko" ]; then
    cat - | $REM_NON_PRINT_CHAR | ${KO_SEGMENT}
# Romanian
elif [ "$lg" = "ro" ]; then
    cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -no-escape -threads $N_THREADS -l $lg
# Burmese
elif [ "$lg" = "my" ]; then
    cat - | python ${MY_SEGMENT}
# Arabic
elif [ "$lg" = "ar" ]; then
    cat - | ${AR_TOKENIZER}
# Indic
elif [ "$lg" = "ne" ]; then
    cat - | python ${IN_TOKENIZER} $lg
elif [ "$lg" = "si" ]; then
    cat - | python ${IN_TOKENIZER} $lg
elif [ "$lg" = "hi" ]; then
    cat - | python ${IN_TOKENIZER} $lg
# other languages
else
    cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape -threads $N_THREADS -l $lg
fi
fairseq-0.10.2/examples/mbart/README.md
# MBART: Multilingual Denoising Pre-training for Neural Machine Translation

[https://arxiv.org/abs/2001.08210]

## Introduction

MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scale monolingual corpora in many languages using the BART objective. mBART is one of the first methods for pre-training a complete sequence-to-sequence model by denoising full texts in multiple languages, while previous approaches have focused only on the encoder, decoder, or reconstructing parts of the text.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`mbart.CC25` | mBART model with 12 encoder and decoder layers trained on the monolingual corpora of 25 languages | 610M | [mbart.CC25.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz)
`mbart.ft.ro_en` | mBART CC25 model finetuned on the RO-EN language pair | 610M | [mbart.cc25.ft.enro.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz)

## Results

**[WMT16 EN-RO](https://www.statmt.org/wmt16/translation-task.html)**

_(test set, no additional data used)_

Model | en-ro | ro-en
---|---|---
`Random` | 34.3 | 34.0
`mbart.cc25` | 37.7 | 37.8
`mbart.enro.bilingual` | 38.5 | 38.5

## BPE data

Download the model and apply SentencePiece BPE to your data. Install SPM [here](https://github.com/google/sentencepiece).

```bash
# download model
wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz
tar -xzvf mbart.CC25.tar.gz

# bpe data
SPM=/path/to/sentencepiece/build/src/spm_encode
MODEL=sentence.bpe.model
${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${SRC} > ${DATA}/${TRAIN}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${TGT} > ${DATA}/${TRAIN}.spm.${TGT} &
${SPM} --model=${MODEL} < ${DATA}/${VALID}.${SRC} > ${DATA}/${VALID}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${VALID}.${TGT} > ${DATA}/${VALID}.spm.${TGT} &
${SPM} --model=${MODEL} < ${DATA}/${TEST}.${SRC} > ${DATA}/${TEST}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${TEST}.${TGT} > ${DATA}/${TEST}.spm.${TGT} &
```

## Preprocess data

```bash
DICT=dict.txt
fairseq-preprocess \
    --source-lang ${SRC} \
    --target-lang ${TGT} \
    --trainpref ${DATA}/${TRAIN}.spm \
    --validpref ${DATA}/${VALID}.spm \
    --testpref ${DATA}/${TEST}.spm \
    --destdir ${DEST}/${NAME} \
    --thresholdtgt 0 \
    --thresholdsrc 0 \
    --srcdict ${DICT} \
    --tgtdict ${DICT} \
    --workers 70
```

## Finetune on EN-RO
Finetune the mBART CC25 model:

```bash
PRETRAIN=mbart.cc25  # fix if you moved the downloaded checkpoint
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN

fairseq-train path_2_data \
    --encoder-normalize-before --decoder-normalize-before \
    --arch mbart_large --layernorm-embedding \
    --task translation_from_pretrained_bart \
    --source-lang en_XX --target-lang ro_RO \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
    --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
    --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \
    --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
    --max-tokens 1024 --update-freq 2 \
    --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
    --seed 222 --log-format simple --log-interval 2 \
    --restore-file $PRETRAIN \
    --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
    --langs $langs \
    --ddp-backend no_c10d
```

## Generate on EN-RO
Compute sacreBLEU with the finetuned EN-RO model.

Get the tokenizer [here](https://github.com/rsennrich/wmt16-scripts).
```bash
wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz
tar -xzvf mbart.cc25.ft.enro.tar.gz
```

```bash
model_dir=MBART_finetuned_enro  # fix if you moved the checkpoint

fairseq-generate path_2_data \
    --path $model_dir/model.pt \
    --task translation_from_pretrained_bart \
    --gen-subset test \
    -t ro_RO -s en_XX \
    --bpe 'sentencepiece' --sentencepiece-model $model_dir/sentence.bpe.model \
    --sacrebleu --remove-bpe 'sentencepiece' \
    --batch-size 32 --langs $langs > en_ro

cat en_ro | grep -P "^H" | sort -V | cut -f 3- | sed 's/\[ro_RO\]//g' | $TOKENIZER ro > en_ro.hyp
cat en_ro | grep -P "^T" | sort -V | cut -f 2- | sed 's/\[ro_RO\]//g' | $TOKENIZER ro > en_ro.ref
sacrebleu -tok 'none' -s 'none' en_ro.ref < en_ro.hyp
```
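The `grep`/`sort`/`cut` pipeline above extracts hypotheses and references from fairseq-generate's plain-text output, which tags hypothesis lines as `H-<id>` and reference lines as `T-<id>` (tab-separated). A minimal Python equivalent of that extraction, with made-up sample lines:

```python
# Minimal sketch of the grep "^H" | sort -V | cut extraction above.
# fairseq-generate writes "H-<id>\t<score>\t<text>" and "T-<id>\t<text>" lines.

def extract(lines, prefix, field):
    """Collect tagged lines, sort by sentence id, keep the text field."""
    rows = []
    for line in lines:
        if line.startswith(prefix + "-"):
            parts = line.rstrip("\n").split("\t")
            sent_id = int(parts[0].split("-")[1])
            rows.append((sent_id, parts[field]))
    return [text for _, text in sorted(rows)]

sample = [
    "T-1\treference two",
    "H-1\t-0.3\thypothesis two",
    "H-0\t-0.2\thypothesis one",
    "T-0\treference one",
]
hyps = extract(sample, "H", 2)   # like: grep -P "^H" | sort -V | cut -f 3-
refs = extract(sample, "T", 1)   # like: grep -P "^T" | sort -V | cut -f 2-
print(hyps)  # ['hypothesis one', 'hypothesis two']
print(refs)  # ['reference one', 'reference two']
```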

## Citation

```bibtex
@article{liu2020multilingual,
    title={Multilingual Denoising Pre-training for Neural Machine Translation},
    author={Yinhan Liu and Jiatao Gu and Naman Goyal and Xian Li and Sergey Edunov and Marjan Ghazvininejad and Mike Lewis and Luke Zettlemoyer},
    year={2020},
    eprint={2001.08210},
    archivePrefix={arXiv},
    primaryClass={cs.CL}
}
```
fairseq-0.10.2/examples/roberta/README.glue.md
# Finetuning RoBERTa on GLUE tasks

### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```

### 2) Preprocess GLUE task data:
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` for preprocessing all the GLUE tasks.

### 3) Fine-tuning on GLUE task:
Example fine-tuning command for the `RTE` task:
```bash
TOTAL_NUM_UPDATES=2036  # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=122      # 6 percent of the number of updates
LR=2e-05                # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16        # Batch size.
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin/ \
    --restore-file $ROBERTA_PATH \
    --max-positions 512 \
    --batch-size $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 --separator-token 2 \
    --arch roberta_large \
    --criterion sentence_prediction \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```

For each GLUE task, you will need to use the following command-line arguments:

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 1e-5 | 1e-5 | 1e-5 | 2e-5 | 1e-5 | 1e-5 | 1e-5 | 2e-5
`--batch-size` | 32 | 32 | 32 | 16 | 32 | 16 | 16 | 16
`--total-num-update` | 123873 | 33112 | 113272 | 2036 | 20935 | 2296 | 5336 | 3598
`--warmup-updates` | 7432 | 1986 | 28318 | 122 | 1256 | 137 | 320 | 214

For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.

**Note:**

a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=16/32` depending on the task.

b) The above command-line args and hyperparams were tested on a single Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.

c) All the settings in the above table are suggested settings based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with a wider hyperparam search.
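As the RTE example shows, most of the `--warmup-updates` values in the table are 6 percent of `--total-num-update` (truncated); a tiny sketch of that rule (the helper name is ours, and a few rows such as QQP and STS-B were tuned separately and deviate):

```python
# Warmup is ~6% of total updates, truncated, for most GLUE rows above.
def warmup_updates(total_num_updates, warmup_ratio=0.06):
    # hypothetical helper; mirrors WARMUP_UPDATES in the RTE example
    return int(total_num_updates * warmup_ratio)

print(warmup_updates(2036))    # 122  (RTE row)
print(warmup_updates(123873))  # 7432 (MNLI row)
```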

### Inference on GLUE task
After training the model as described in the previous step, you can perform inference with checkpoints in the `checkpoints/` directory using the following Python snippet:

```python
from fairseq.models.roberta import RobertaModel

roberta = RobertaModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='RTE-bin'
)

label_fn = lambda label: roberta.task.label_dictionary.string(
    [label + roberta.task.label_dictionary.nspecial]
)
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/RTE/dev.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[1], tokens[2], tokens[3]
        tokens = roberta.encode(sent1, sent2)
        prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
        prediction_label = label_fn(prediction)
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
```
fairseq-0.10.2/examples/roberta/README.md
# RoBERTa: A Robustly Optimized BERT Pretraining Approach

https://arxiv.org/abs/1907.11692

## Introduction

RoBERTa iterates on BERT's pretraining procedure, including training the model longer, with bigger batches over more data; removing the next sentence prediction objective; training on longer sequences; and dynamically changing the masking pattern applied to the training data. See the associated paper for more details.

### What's New:

- January 2020: Italian model (UmBERTo) is available from Musixmatch Research: [UmBERTo](https://github.com/musixmatchresearch/umberto).
- November 2019: French model (CamemBERT) is available: [CamemBERT](https://github.com/pytorch/fairseq/tree/master/examples/camembert).
- November 2019: Multilingual encoder (XLM-RoBERTa) is available: [XLM-R](https://github.com/pytorch/fairseq/tree/master/examples/xlmr).
- September 2019: TensorFlow and TPU support via the [transformers library](https://github.com/huggingface/transformers).
- August 2019: RoBERTa is now supported in the [pytorch-transformers library](https://github.com/huggingface/pytorch-transformers).
- August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/master/examples/roberta/wsc#roberta-training-on-winogrande-dataset).
- August 2019: Added [tutorial for pretraining RoBERTa using your own data](README.pretraining.md).

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
`roberta.large.wsc` | `roberta.large` finetuned on [WSC](wsc/README.md) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)

## Results

**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.base` | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -

**[SuperGLUE (Wang et al., 2019)](https://super.gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | BoolQ | CB | COPA | MultiRC | RTE | WiC | WSC
---|---|---|---|---|---|---|---
`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | -
`roberta.large.wsc` | - | - | - | - | - | - | 91.3

**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4

**[RACE (Lai et al., 2017)](http://www.qizhexie.com/data/RACE_leaderboard.html)**
_(test set)_

Model | Accuracy | Middle | High
---|---|---|---
`roberta.large` | 83.2 | 86.5 | 81.3

**[HellaSwag (Zellers et al., 2019)](https://rowanzellers.com/hellaswag/)**
_(test set)_

Model | Overall | In-domain | Zero-shot | ActivityNet | WikiHow
---|---|---|---|---|---
`roberta.large` | 85.2 | 87.3 | 83.1 | 74.6 | 90.9

**[Commonsense QA (Talmor et al., 2019)](https://www.tau-nlp.org/commonsenseqa)**
_(test set)_

Model | Accuracy
---|---
`roberta.large` (single model) | 72.1
`roberta.large` (ensemble) | 72.5

**[Winogrande (Sakaguchi et al., 2019)](https://arxiv.org/abs/1907.10641)**
_(test set)_

Model | Accuracy
---|---
`roberta.large` | 78.1

**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)**
_(TRANSLATE-TEST)_

Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
`roberta.large.mnli` | 91.3 | 82.91 | 84.27 | 81.24 | 81.74 | 83.13 | 78.28 | 76.79 | 76.64 | 74.17 | 74.05 | 77.5 | 70.9 | 66.65 | 66.81

## Example usage

##### Load RoBERTa from torch.hub (PyTorch >= 1.1):
```python
import torch
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Load RoBERTa (for PyTorch 1.0 or custom models):
```bash
# Download roberta.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
tar -xzvf roberta.large.tar.gz
```

```python
# Load the model in fairseq
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained('/path/to/roberta.large', checkpoint_file='model.pt')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
|
| 112 |
+
```python
|
| 113 |
+
tokens = roberta.encode('Hello world!')
|
| 114 |
+
assert tokens.tolist() == [0, 31414, 232, 328, 2]
|
| 115 |
+
roberta.decode(tokens) # 'Hello world!'
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
##### Extract features from RoBERTa:
|
| 119 |
+
```python
|
| 120 |
+
# Extract the last layer's features
|
| 121 |
+
last_layer_features = roberta.extract_features(tokens)
|
| 122 |
+
assert last_layer_features.size() == torch.Size([1, 5, 1024])
|
| 123 |
+
|
| 124 |
+
# Extract all layer's features (layer 0 is the embedding layer)
|
| 125 |
+
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
|
| 126 |
+
assert len(all_layers) == 25
|
| 127 |
+
assert torch.all(all_layers[-1] == last_layer_features)
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
##### Use RoBERTa for sentence-pair classification tasks:
|
| 131 |
+
```python
|
| 132 |
+
# Download RoBERTa already finetuned for MNLI
|
| 133 |
+
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
|
| 134 |
+
roberta.eval() # disable dropout for evaluation
|
| 135 |
+
|
| 136 |
+
# Encode a pair of sentences and make a prediction
|
| 137 |
+
tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.')
|
| 138 |
+
roberta.predict('mnli', tokens).argmax() # 0: contradiction
|
| 139 |
+
|
| 140 |
+
# Encode another pair of sentences
|
| 141 |
+
tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.')
|
| 142 |
+
roberta.predict('mnli', tokens).argmax() # 2: entailment
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
##### Register a new (randomly initialized) classification head:
|
| 146 |
+
```python
|
| 147 |
+
roberta.register_classification_head('new_task', num_classes=3)
|
| 148 |
+
logprobs = roberta.predict('new_task', tokens) # tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
##### Batched prediction:
|
| 152 |
+
```python
|
| 153 |
+
import torch
|
| 154 |
+
from fairseq.data.data_utils import collate_tokens
|
| 155 |
+
|
| 156 |
+
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
|
| 157 |
+
roberta.eval()
|
| 158 |
+
|
| 159 |
+
batch_of_pairs = [
|
| 160 |
+
['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
|
| 161 |
+
['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'],
|
| 162 |
+
['potatoes are awesome.', 'I like to run.'],
|
| 163 |
+
['Mars is very far from earth.', 'Mars is very close.'],
|
| 164 |
+
]
|
| 165 |
+
|
| 166 |
+
batch = collate_tokens(
|
| 167 |
+
[roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
logprobs = roberta.predict('mnli', batch)
|
| 171 |
+
print(logprobs.argmax(dim=1))
|
| 172 |
+
# tensor([0, 2, 1, 0])
|
| 173 |
+
```
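For reference, `collate_tokens` right-pads each encoded pair up to the length of the longest item so the batch forms a single rectangular tensor. A minimal pure-Python sketch of that padding step, using toy token lists and pad index 1 as above (`collate` here is an illustrative stand-in, not the fairseq API):

```python
def collate(token_lists, pad_idx=1):
    """Right-pad variable-length token lists into a rectangular batch."""
    max_len = max(len(t) for t in token_lists)
    return [t + [pad_idx] * (max_len - len(t)) for t in token_lists]

# Two "encoded" sequences of different lengths.
batch = collate([[0, 7, 2], [0, 9, 9, 9, 2]])
print(batch)
# [[0, 7, 2, 1, 1], [0, 9, 9, 9, 2]]
```

The padded positions are ignored by the model via the padding index, which is why `pad_idx` must match the dictionary's pad symbol.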

##### Using the GPU:
```python
roberta.cuda()
roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```

## Advanced usage

#### Filling masks:

RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python
roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
# [('The first Star wars movie came out in 1977', 0.9504708051681519, ' 1977'), ('The first Star wars movie came out in 1978', 0.009986862540245056, ' 1978'), ('The first Star wars movie came out in 1979', 0.009574787691235542, ' 1979')]

roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
# [('Vikram samvat calender is official in India', 0.21878819167613983, ' India'), ('Vikram samvat calender is official in Delhi', 0.08547237515449524, ' Delhi'), ('Vikram samvat calender is official in Gujarat', 0.07556215673685074, ' Gujarat')]

roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
# [('Euro is the common currency of the European Union', 0.9456493854522705, 'Euro'), ('euro is the common currency of the European Union', 0.025748178362846375, 'euro'), ('€ is the common currency of the European Union', 0.011183084920048714, '€')]
```

#### Pronoun disambiguation (Winograd Schema Challenge):

RoBERTa can be used to disambiguate pronouns. First install spaCy and download the English-language model:
```bash
pip install spacy
python -m spacy download en_core_web_lg
```

Next load the `roberta.large.wsc` model and call the `disambiguate_pronoun`
function. The pronoun should be surrounded by square brackets (`[]`) and the
query referent surrounded by underscores (`_`), or left blank to return the
predicted candidate text directly:
```python
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc', user_dir='examples/roberta/wsc')
roberta.cuda()  # use the GPU (optional)

roberta.disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
# True
roberta.disambiguate_pronoun('The trophy would not fit in the brown _suitcase_ because [it] was too big.')
# False

roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] feared violence.')
# 'The city councilmen'
roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] advocated violence.')
# 'demonstrators'
```

See the [RoBERTa Winograd Schema Challenge (WSC) README](wsc/README.md) for more details on how to train this model.

#### Extract features aligned to words:

By default RoBERTa outputs one feature vector per BPE token. You can instead
realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
with the `extract_features_aligned_to_words` method. This will compute a
weighted average of the BPE-level features for each word and expose them in
spaCy's `Token.vector` attribute:
```python
doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
assert len(doc) == 10
for tok in doc:
    print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
# <s>       tensor([-0.1316, -0.0386, -0.0832, -0.0477,  0.1943], grad_fn=<SliceBackward>) (...)
# I         tensor([ 0.0559,  0.1541, -0.4832,  0.0880,  0.0120], grad_fn=<SliceBackward>) (...)
# said      tensor([-0.1565, -0.0069, -0.8915,  0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
# ,         tensor([-0.1318, -0.0387, -0.0834, -0.0477,  0.1944], grad_fn=<SliceBackward>) (...)
# "         tensor([-0.0486,  0.1818, -0.3946, -0.0553,  0.0981], grad_fn=<SliceBackward>) (...)
# hello     tensor([ 0.0079,  0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
# RoBERTa   tensor([-0.2339, -0.1184, -0.7343, -0.0492,  0.5829], grad_fn=<SliceBackward>) (...)
# .         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
# "         tensor([-0.1341, -0.1203, -0.1012, -0.0621,  0.1892], grad_fn=<SliceBackward>) (...)
# </s>      tensor([-0.0930, -0.0392, -0.0821,  0.0158,  0.0649], grad_fn=<SliceBackward>) (...)
```
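The realignment above pools several BPE-level vectors into one vector per word. The sketch below shows that pooling with a plain uniform average (fairseq's actual weighting of word-boundary tokens may differ); the `bpe_to_word` alignment and the 2-dimensional toy vectors are invented for illustration:

```python
# Toy BPE-level feature vectors and an alignment that maps each BPE token
# to the index of the word it belongs to.
bpe_vectors = [[1.0, 0.0], [0.0, 1.0], [4.0, 4.0]]
bpe_to_word = [0, 0, 1]  # first two BPE tokens form word 0, the last one word 1

def align_to_words(vectors, alignment):
    """Average the BPE vectors assigned to each word (uniform weights)."""
    num_words = max(alignment) + 1
    dim = len(vectors[0])
    sums = [[0.0] * dim for _ in range(num_words)]
    counts = [0] * num_words
    for vec, w in zip(vectors, alignment):
        sums[w] = [s + v for s, v in zip(sums[w], vec)]
        counts[w] += 1
    return [[s / c for s in vec_sum] for vec_sum, c in zip(sums, counts)]

print(align_to_words(bpe_vectors, bpe_to_word))
# [[0.5, 0.5], [4.0, 4.0]]
```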

#### Evaluating the `roberta.large.mnli` model:

Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
        tokens = roberta.encode(sent1, sent2)
        prediction = roberta.predict('mnli', tokens).argmax().item()
        prediction_label = label_map[prediction]
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9060
```

## Finetuning

- [Finetuning on GLUE](README.glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.custom_classification.md)
- [Finetuning on Winograd Schema Challenge (WSC)](wsc/README.md)
- [Finetuning on Commonsense QA (CQA)](commonsense_qa/README.md)
- Finetuning on SQuAD: coming soon

## Pretraining using your own data

See the [tutorial for pretraining RoBERTa using your own data](README.pretraining.md).

## Citation

```bibtex
@article{liu2019roberta,
  title = {RoBERTa: A Robustly Optimized BERT Pretraining Approach},
  author = {Yinhan Liu and Myle Ott and Naman Goyal and Jingfei Du and
            Mandar Joshi and Danqi Chen and Omer Levy and Mike Lewis and
            Luke Zettlemoyer and Veselin Stoyanov},
  journal = {arXiv preprint arXiv:1907.11692},
  year = {2019},
}
```
fairseq-0.10.2/examples/roberta/README.pretraining.md
ADDED
@@ -0,0 +1,98 @@
# Pretraining RoBERTa using your own data

This tutorial will walk you through pretraining RoBERTa over your own data.

### 1) Preprocess the data

Data should be preprocessed following the [language modeling format](/examples/language_model), i.e. each document should be separated by an empty line (only useful with `--sample-break-mode complete_doc`). Lines will be concatenated as a 1D text stream during training.
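As a concrete illustration of this format, the snippet below builds a tiny corpus in memory and splits it into documents on blank lines, which is the boundary `--sample-break-mode complete_doc` relies on (the variable names here are illustrative only, not part of fairseq):

```python
# A raw corpus in the expected format: one or more lines per document,
# with documents separated by a single empty line.
corpus = """First document, sentence one.
First document, sentence two.

Second document, a single line.
"""

# Split the 1D text stream into documents on blank lines.
docs = [block.splitlines() for block in corpus.split("\n\n") if block.strip()]

print(len(docs))   # 2
print(docs[1][0])  # Second document, a single line.
```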
We'll use the [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/)
to demonstrate how to preprocess raw text data with the GPT-2 BPE. Of course
this dataset is quite small, so the resulting pretrained model will perform
poorly, but it gives the general idea.

First download the dataset:
```bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
unzip wikitext-103-raw-v1.zip
```

Next encode it with the GPT-2 BPE:
```bash
mkdir -p gpt2_bpe
wget -O gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
wget -O gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
for SPLIT in train valid test; do \
    python -m examples.roberta.multiprocessing_bpe_encoder \
        --encoder-json gpt2_bpe/encoder.json \
        --vocab-bpe gpt2_bpe/vocab.bpe \
        --inputs wikitext-103-raw/wiki.${SPLIT}.raw \
        --outputs wikitext-103-raw/wiki.${SPLIT}.bpe \
        --keep-empty \
        --workers 60; \
done
```

Finally preprocess/binarize the data using the GPT-2 fairseq dictionary:
```bash
wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
fairseq-preprocess \
    --only-source \
    --srcdict gpt2_bpe/dict.txt \
    --trainpref wikitext-103-raw/wiki.train.bpe \
    --validpref wikitext-103-raw/wiki.valid.bpe \
    --testpref wikitext-103-raw/wiki.test.bpe \
    --destdir data-bin/wikitext-103 \
    --workers 60
```

### 2) Train RoBERTa base
```bash
TOTAL_UPDATES=125000    # Total number of training steps
WARMUP_UPDATES=10000    # Warmup the learning rate over this many updates
PEAK_LR=0.0005          # Peak learning rate, adjust as needed
TOKENS_PER_SAMPLE=512   # Max sequence length
MAX_POSITIONS=512       # Num. positional embeddings (usually same as above)
MAX_SENTENCES=16        # Number of sequences per batch (batch size)
UPDATE_FREQ=16          # Increase the batch size 16x

DATA_DIR=data-bin/wikitext-103

fairseq-train --fp16 $DATA_DIR \
    --task masked_lm --criterion masked_lm \
    --arch roberta_base --sample-break-mode complete --tokens-per-sample $TOKENS_PER_SAMPLE \
    --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-6 --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \
    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ \
    --max-update $TOTAL_UPDATES --log-format simple --log-interval 1
```
**Note:** You can optionally resume training the released RoBERTa base model by
adding `--restore-file /path/to/roberta.base/model.pt`.

**Note:** The above command assumes training on 8x32GB V100 GPUs. Each GPU uses
a batch size of 16 sequences (`$MAX_SENTENCES`) and accumulates gradients to
further increase the batch size by 16x (`$UPDATE_FREQ`), for a total batch size
of 2048 sequences. If you have fewer GPUs or GPUs with less memory, you may need
to reduce `$MAX_SENTENCES` and increase `$UPDATE_FREQ` to compensate.
Alternatively, if you have more GPUs you can decrease `$UPDATE_FREQ` accordingly
to increase training speed.

**Note:** The learning rate and batch size are tightly connected and need to be
adjusted together. We generally recommend increasing the learning rate as you
increase the batch size according to the following table (although it's also
dataset dependent, so don't rely on the following values too closely):

batch size | peak learning rate
---|---
256 | 0.0001
2048 | 0.0005
8192 | 0.0007
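The effective batch size discussed in the notes above is simply the product of three knobs: GPU count, per-GPU batch size, and gradient-accumulation steps. A quick sanity check in plain Python, with the numbers taken from the command above (the helper function is illustrative, not part of fairseq):

```python
def effective_batch_size(num_gpus: int, max_sentences: int, update_freq: int) -> int:
    """Total number of sequences contributing to one optimizer step."""
    return num_gpus * max_sentences * update_freq

# 8 GPUs x 16 sequences per GPU x 16 gradient-accumulation steps = 2048
assert effective_batch_size(8, 16, 16) == 2048

# Halving the GPU count can be compensated by doubling --update-freq:
assert effective_batch_size(4, 16, 32) == 2048
```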
### 3) Load your pretrained model
```python
import torch
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'path/to/data')
assert isinstance(roberta.model, torch.nn.Module)
```
fairseq-0.10.2/examples/roberta/README.race.md
ADDED
@@ -0,0 +1,68 @@
# Finetuning RoBERTa on RACE tasks

### 1) Download the data from the RACE website (http://www.cs.cmu.edu/~glai1/data/race/)

### 2) Preprocess RACE data:
```bash
python ./examples/roberta/preprocess_RACE.py --input-dir <input-dir> --output-dir <extracted-data-dir>
./examples/roberta/preprocess_RACE.sh <extracted-data-dir> <output-dir>
```

### 3) Fine-tuning on RACE:

```bash
MAX_EPOCH=5        # Number of training epochs.
LR=1e-05           # Peak LR for fixed LR scheduler.
NUM_CLASSES=4
MAX_SENTENCES=1    # Batch size per GPU.
UPDATE_FREQ=8      # Accumulate gradients to simulate training on 8 GPUs.
DATA_DIR=/path/to/race-output-dir
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR --ddp-backend=no_c10d \
    --restore-file $ROBERTA_PATH \
    --reset-optimizer --reset-dataloader --reset-meters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --task sentence_ranking \
    --num-classes $NUM_CLASSES \
    --init-token 0 --separator-token 2 \
    --max-option-length 128 \
    --max-positions 512 \
    --shorten-method "truncate" \
    --arch roberta_large \
    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --criterion sentence_ranking \
    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler fixed --lr $LR \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --batch-size $MAX_SENTENCES \
    --required-batch-size-multiple 1 \
    --update-freq $UPDATE_FREQ \
    --max-epoch $MAX_EPOCH
```

**Note:**

a) As contexts in RACE are relatively long, we use a smaller batch size per GPU while increasing `--update-freq` to achieve a larger effective batch size.

b) The above cmd-args and hyperparams were tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.

c) The setting in the above command is based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with a wider hyperparam search.

### 4) Evaluation:

```bash
DATA_DIR=/path/to/race-output-dir       # data directory used during training
MODEL_PATH=/path/to/checkpoint_best.pt  # path to the finetuned model checkpoint
PREDS_OUT=preds.tsv                     # output file path to save prediction
TEST_SPLIT=test                         # can be test (Middle) or test1 (High)
fairseq-validate \
    $DATA_DIR \
    --valid-subset $TEST_SPLIT \
    --path $MODEL_PATH \
    --batch-size 1 \
    --task sentence_ranking \
    --criterion sentence_ranking \
    --save-predictions $PREDS_OUT
```
fairseq-0.10.2/examples/roberta/multiprocessing_bpe_encoder.py
ADDED
@@ -0,0 +1,130 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool

from fairseq.data.encoders.gpt2_bpe import get_encoder


def main():
    """
    Helper script to encode raw text with the GPT-2 BPE using multiple processes.

    The encoder.json and vocab.bpe files can be obtained here:
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help="path to encoder.json",
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help="path to vocab.bpe",
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=["-"],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=["-"],
        help="path to save encoded outputs",
    )
    parser.add_argument(
        "--keep-empty",
        action="store_true",
        help="keep empty lines",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    assert len(args.inputs) == len(
        args.outputs
    ), "number of input and output paths should match"

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(input, "r", encoding="utf-8"))
            if input != "-"
            else sys.stdin
            for input in args.inputs
        ]
        outputs = [
            stack.enter_context(open(output, "w", encoding="utf-8"))
            if output != "-"
            else sys.stdout
            for output in args.outputs
        ]

        encoder = MultiprocessingEncoder(args)
        pool = Pool(args.workers, initializer=encoder.initializer)
        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)

        stats = Counter()
        for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
            if filt == "PASS":
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(enc_line, file=output_h)
            else:
                stats["num_filtered_" + filt] += 1
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)


class MultiprocessingEncoder(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        global bpe
        bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)

    def encode(self, line):
        global bpe
        ids = bpe.encode(line)
        return list(map(str, ids))

    def decode(self, tokens):
        global bpe
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """
        Encode a set of lines. All lines will be encoded together.
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            if len(line) == 0 and not self.args.keep_empty:
                return ["EMPTY", None]
            tokens = self.encode(line)
            enc_lines.append(" ".join(tokens))
        return ["PASS", enc_lines]

    def decode_lines(self, lines):
        dec_lines = []
        for line in lines:
            tokens = map(int, line.strip().split())
            dec_lines.append(self.decode(tokens))
        return ["PASS", dec_lines]


if __name__ == "__main__":
    main()
fairseq-0.10.2/examples/roberta/preprocess_GLUE_tasks.sh
ADDED
@@ -0,0 +1,185 @@
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# Raw GLUE data as downloaded by the GLUE download script (https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
if [[ $# -ne 2 ]]; then
  echo "Run as following:"
  echo "./examples/roberta/preprocess_GLUE_tasks.sh <glue_data_folder> <task_name>"
  exit 1
fi

GLUE_DATA_FOLDER=$1

# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

TASKS=$2 # QQP

if [ "$TASKS" = "ALL" ]
then
  TASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA"
fi

for TASK in $TASKS
do
  echo "Preprocessing $TASK"

  TASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK"
  echo "Raw data as downloaded from glue website: $TASK_DATA_FOLDER"

  SPLITS="train dev test"
  INPUT_COUNT=2
  if [ "$TASK" = "QQP" ]
  then
    INPUT_COLUMNS=( 4 5 )
    TEST_INPUT_COLUMNS=( 2 3 )
    LABEL_COLUMN=6
  elif [ "$TASK" = "MNLI" ]
  then
    SPLITS="train dev_matched dev_mismatched test_matched test_mismatched"
    INPUT_COLUMNS=( 9 10 )
    TEST_INPUT_COLUMNS=( 9 10 )
    DEV_LABEL_COLUMN=16
    LABEL_COLUMN=12
  elif [ "$TASK" = "QNLI" ]
  then
    INPUT_COLUMNS=( 2 3 )
    TEST_INPUT_COLUMNS=( 2 3 )
    LABEL_COLUMN=4
  elif [ "$TASK" = "MRPC" ]
  then
    INPUT_COLUMNS=( 4 5 )
    TEST_INPUT_COLUMNS=( 4 5 )
    LABEL_COLUMN=1
  elif [ "$TASK" = "RTE" ]
  then
    INPUT_COLUMNS=( 2 3 )
    TEST_INPUT_COLUMNS=( 2 3 )
    LABEL_COLUMN=4
  elif [ "$TASK" = "STS-B" ]
  then
    INPUT_COLUMNS=( 8 9 )
    TEST_INPUT_COLUMNS=( 8 9 )
    LABEL_COLUMN=10
  # Following are single sentence tasks.
  elif [ "$TASK" = "SST-2" ]
  then
    INPUT_COLUMNS=( 1 )
    TEST_INPUT_COLUMNS=( 2 )
    LABEL_COLUMN=2
    INPUT_COUNT=1
  elif [ "$TASK" = "CoLA" ]
  then
    INPUT_COLUMNS=( 4 )
    TEST_INPUT_COLUMNS=( 2 )
    LABEL_COLUMN=2
    INPUT_COUNT=1
  fi

  # Strip out header and filter lines that don't have expected number of fields.
  rm -rf "$TASK_DATA_FOLDER/processed"
  mkdir -p "$TASK_DATA_FOLDER/processed"
  for SPLIT in $SPLITS
  do
    # CoLA train and dev don't have a header.
    if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]]
    then
      cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
    else
      tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
    fi

    # Remove unformatted lines from train and dev files for QQP dataset.
    if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]]
    then
      awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
    else
      cp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
    fi
    rm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
  done

  # Split into input0, input1 and label
  for SPLIT in $SPLITS
  do
    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
    do
      if [[ "$SPLIT" != test* ]]
      then
        COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}
      else
        COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}
      fi
      cut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE";
    done

    if [[ "$SPLIT" != test* ]]
    then
      if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]
      then
        cut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
      else
        cut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
      fi
    fi

    # BPE encode.
    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
    do
      LANG="input$INPUT_TYPE"
      echo "BPE encoding $SPLIT/$LANG"
      python -m examples.roberta.multiprocessing_bpe_encoder \
        --encoder-json encoder.json \
        --vocab-bpe vocab.bpe \
        --inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \
        --outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \
        --workers 60 \
        --keep-empty;
    done
  done

  # Remove output directory.
  rm -rf "$TASK-bin"

  DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG"
  TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG"
  if [ "$TASK" = "MNLI" ]
|
| 153 |
+
then
|
| 154 |
+
DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG"
|
| 155 |
+
TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG"
|
| 156 |
+
fi
|
| 157 |
+
|
| 158 |
+
# Run fairseq preprocessing:
|
| 159 |
+
for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
|
| 160 |
+
do
|
| 161 |
+
LANG="input$INPUT_TYPE"
|
| 162 |
+
fairseq-preprocess \
|
| 163 |
+
--only-source \
|
| 164 |
+
--trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \
|
| 165 |
+
--validpref "${DEVPREF//LANG/$LANG}" \
|
| 166 |
+
--testpref "${TESTPREF//LANG/$LANG}" \
|
| 167 |
+
--destdir "$TASK-bin/$LANG" \
|
| 168 |
+
--workers 60 \
|
| 169 |
+
--srcdict dict.txt;
|
| 170 |
+
done
|
| 171 |
+
if [[ "$TASK" != "STS-B" ]]
|
| 172 |
+
then
|
| 173 |
+
fairseq-preprocess \
|
| 174 |
+
--only-source \
|
| 175 |
+
--trainpref "$TASK_DATA_FOLDER/processed/train.label" \
|
| 176 |
+
--validpref "${DEVPREF//LANG/label}" \
|
| 177 |
+
--destdir "$TASK-bin/label" \
|
| 178 |
+
--workers 60;
|
| 179 |
+
else
|
| 180 |
+
# For STS-B output range is converted to be between: [0.0, 1.0]
|
| 181 |
+
mkdir -p "$TASK-bin/label"
|
| 182 |
+
awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label"
|
| 183 |
+
awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label"
|
| 184 |
+
fi
|
| 185 |
+
done
|
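The STS-B branch above rescales the regression labels from their original [0, 5] range to [0, 1] with a one-line awk program. A minimal Python sketch of the same transformation (the function name is ours, not part of fairseq):

```python
# Equivalent of the awk step above: awk '{print $1 / 5.0 }' train.label
def scale_sts_b_labels(lines, max_score=5.0):
    """Rescale STS-B similarity scores from [0, max_score] down to [0, 1]."""
    return [str(float(line.split()[0]) / max_score) for line in lines]

print(scale_sts_b_labels(["4.5", "0.0", "5.0"]))  # ['0.9', '0.0', '1.0']
```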
fairseq-0.10.2/examples/roberta/preprocess_RACE.sh ADDED
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Data should be downloaded and processed with preprocess_RACE.py.
+if [[ $# -ne 2 ]]; then
+    echo "Run as follows:"
+    echo "./examples/roberta/preprocess_RACE.sh <race_data_folder> <output_folder>"
+    exit 1
+fi
+
+RACE_DATA_FOLDER=$1
+OUT_DATA_FOLDER=$2
+
+# download bpe encoder.json, vocabulary and fairseq dictionary
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
+wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
+
+SPLITS="train dev test-middle test-high"
+INPUT_TYPES="input0 input1 input2 input3 input4"
+for INPUT_TYPE in $INPUT_TYPES
+do
+    for SPLIT in $SPLITS
+    do
+        echo "BPE encoding $SPLIT/$INPUT_TYPE"
+        python -m examples.roberta.multiprocessing_bpe_encoder \
+            --encoder-json encoder.json \
+            --vocab-bpe vocab.bpe \
+            --inputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE" \
+            --outputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE.bpe" \
+            --workers 10 \
+            --keep-empty;
+    done
+done
+
+for INPUT_TYPE in $INPUT_TYPES
+do
+    fairseq-preprocess \
+        --only-source \
+        --trainpref "$RACE_DATA_FOLDER/train.$INPUT_TYPE.bpe" \
+        --validpref "$RACE_DATA_FOLDER/dev.$INPUT_TYPE.bpe" \
+        --testpref "$RACE_DATA_FOLDER/test-middle.$INPUT_TYPE.bpe,$RACE_DATA_FOLDER/test-high.$INPUT_TYPE.bpe" \
+        --destdir "$OUT_DATA_FOLDER/$INPUT_TYPE" \
+        --workers 10 \
+        --srcdict dict.txt;
+done
+
+rm -rf "$OUT_DATA_FOLDER/label"
+mkdir -p "$OUT_DATA_FOLDER/label"
+cp "$RACE_DATA_FOLDER/train.label" "$OUT_DATA_FOLDER/label/"
+cp "$RACE_DATA_FOLDER/dev.label" "$OUT_DATA_FOLDER/label/valid.label"
+cp "$RACE_DATA_FOLDER/test-middle.label" "$OUT_DATA_FOLDER/label/test.label"
+cp "$RACE_DATA_FOLDER/test-high.label" "$OUT_DATA_FOLDER/label/test1.label"
fairseq-0.10.2/examples/speech_recognition/README.md ADDED
@@ -0,0 +1,106 @@
+# Speech Recognition
+`examples/speech_recognition` implements the ASR task in fairseq, along with the features, datasets, models and loss functions needed to train and run inference with the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
+
+
+## Additional dependencies
+On top of the main fairseq dependencies there are a couple of additional requirements.
+
+1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
+2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and installed from source from the sctk package [here](http://www.openslr.org/4/). Training and inference do not require Sclite.
+3) [sentencepiece](https://github.com/google/sentencepiece) is required to create datasets with word-piece targets.
+
+## Preparing librispeech data
+```
+./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
+```
+
+## Training librispeech data
+```
+python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
+```
+
+## Inference for librispeech
+`$SET` can be `test_clean` or `test_other`.
+Any checkpoint in `$MODEL_PATH` can be selected. In this example we are working with `checkpoint_last.pt`.
+```
+python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
+```
+
+## Scoring with Sclite
+```
+sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
+```
+The `Sum/Avg` row of the first table in the report contains the WER.
+
+## Using wav2letter components
+[wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes:
+
+* AutoSegmentationCriterion (ASG)
+* wav2letter-style Conv/GLU model
+* wav2letter's beam search decoder
+
+To use these, follow the instructions on [this page](https://github.com/facebookresearch/wav2letter/tree/master/bindings/python) to install python bindings. Please note that python bindings are for a *subset* of wav2letter and don't require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required).
+
+To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps:
+```
+# additional prerequisites - use equivalents for your distro
+sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev
+# install KenLM from source
+git clone https://github.com/kpu/kenlm.git
+cd kenlm
+mkdir -p build && cd build
+cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON
+make -j16
+cd ..
+export KENLM_ROOT_DIR=$(pwd)
+cd ..
+# install wav2letter python bindings
+git clone https://github.com/facebookresearch/wav2letter.git
+cd wav2letter/bindings/python
+# make sure your python environment is active at this point
+pip install torch packaging
+pip install -e .
+# try some examples to verify installation succeeded
+python ./examples/criterion_example.py
+python ./examples/decoder_example.py ../../src/decoder/test
+python ./examples/feature_example.py ../../src/feature/test/data
+```
+
+## Training librispeech data (wav2letter style, Conv/GLU + ASG loss)
+Training command:
+```
+python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
+```
+
+Note that the ASG loss currently doesn't do well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
+
+## Inference for librispeech (wav2letter decoder, n-gram LM)
+Inference command:
+```
+python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
+```
+
+`$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (a list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
+```
+doorbell D O 1 R B E L 1 ▁
+```
+For CTC inference with word-pieces, repetition labels are not used and the lexicon should have the most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this):
+```
+doorbell ▁DOOR BE LL
+doorbell ▁DOOR B E LL
+doorbell ▁DO OR BE LL
+doorbell ▁DOOR B EL L
+doorbell ▁DOOR BE L L
+doorbell ▁DO OR B E LL
+doorbell ▁DOOR B E L L
+doorbell ▁DO OR B EL L
+doorbell ▁DO O R BE LL
+doorbell ▁DO OR BE L L
+```
+Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
+
+## Inference for librispeech (wav2letter decoder, viterbi only)
+Inference command:
+```
+python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
+```
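The repetition labels in the README's ASG lexicon line (`doorbell -> D O 1 R B E L 1`) replace a run of identical letters with the letter followed by a repeat count. A minimal sketch of that encoding assuming `--max-replabel 2`; this helper is illustrative and is not the fairseq implementation:

```python
def asg_spelling(word, max_replabel=2):
    """Encode a word's spelling with wav2letter-style repetition labels.

    A run of k identical letters becomes the letter followed by the replabel
    token str(k - 1); runs longer than max_replabel + 1 are split into
    several such groups.
    """
    out = []
    i = 0
    while i < len(word):
        # find the end of the current run of identical letters
        j = i
        while j < len(word) and word[j] == word[i]:
            j += 1
        run = j - i
        # emit the run in chunks of at most max_replabel + 1 letters
        while run > 0:
            chunk = min(run, max_replabel + 1)
            out.append(word[i].upper())
            if chunk > 1:
                out.append(str(chunk - 1))
            run -= chunk
        i = j
    return " ".join(out)

print(asg_spelling("doorbell"))  # D O 1 R B E L 1
```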
fairseq-0.10.2/examples/speech_recognition/__init__.py ADDED
@@ -0,0 +1 @@
+from . import criterions, models, tasks  # noqa
fairseq-0.10.2/examples/speech_recognition/criterions/__init__.py ADDED
@@ -0,0 +1,17 @@
+import importlib
+import os
+
+
+# ASG loss requires wav2letter
+files_to_skip = set()
+try:
+    import wav2letter
+except ImportError:
+    files_to_skip.add("ASG_loss.py")
+
+for file in os.listdir(os.path.dirname(__file__)):
+    if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip:
+        criterion_name = file[: file.find(".py")]
+        importlib.import_module(
+            "examples.speech_recognition.criterions." + criterion_name
+        )
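The loop in `criterions/__init__.py` is the standard fairseq registry idiom: importing every module in the package triggers registration decorators such as `@register_criterion` as a side effect. A self-contained sketch of the same pattern against a throwaway directory (all names here are illustrative):

```python
import importlib
import os
import sys
import tempfile

# Build a throwaway "package" directory containing one public module and
# one underscore-prefixed module that should be skipped.
pkg_dir = tempfile.mkdtemp()
with open(os.path.join(pkg_dir, "my_criterion.py"), "w") as f:
    f.write("REGISTERED = 'my_criterion'\n")
with open(os.path.join(pkg_dir, "_private.py"), "w") as f:
    f.write("REGISTERED = 'should be skipped'\n")

# Import every non-underscore .py file, as criterions/__init__.py does;
# in fairseq the import side effect runs the registration decorators.
sys.path.insert(0, pkg_dir)
loaded = []
for file in sorted(os.listdir(pkg_dir)):
    if file.endswith(".py") and not file.startswith("_"):
        module = importlib.import_module(file[: file.find(".py")])
        loaded.append(module.REGISTERED)

print(loaded)  # ['my_criterion']
```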
fairseq-0.10.2/examples/speech_recognition/criterions/cross_entropy_acc.py ADDED
@@ -0,0 +1,130 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import logging
+import math
+
+import torch
+import torch.nn.functional as F
+from fairseq import utils
+from fairseq.criterions import FairseqCriterion, register_criterion
+
+
+@register_criterion("cross_entropy_acc")
+class CrossEntropyWithAccCriterion(FairseqCriterion):
+    def __init__(self, task, sentence_avg):
+        super().__init__(task)
+        self.sentence_avg = sentence_avg
+
+    def compute_loss(self, model, net_output, target, reduction, log_probs):
+        # N, T -> N * T
+        target = target.view(-1)
+        lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
+        if not hasattr(lprobs, "batch_first"):
+            logging.warning(
+                "ERROR: we need to know whether "
+                "batch first for the net output; "
+                "you need to set batch_first attribute for the return value of "
+                "model.get_normalized_probs. Now, we assume this is true, but "
+                "in the future, we will raise exception instead. "
+            )
+        batch_first = getattr(lprobs, "batch_first", True)
+        if not batch_first:
+            lprobs = lprobs.transpose(0, 1)
+
+        # N, T, D -> N * T, D
+        lprobs = lprobs.view(-1, lprobs.size(-1))
+        loss = F.nll_loss(
+            lprobs, target, ignore_index=self.padding_idx, reduction=reduction
+        )
+        return lprobs, loss
+
+    def get_logging_output(self, sample, target, lprobs, loss):
+        target = target.view(-1)
+        mask = target != self.padding_idx
+        correct = torch.sum(
+            lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
+        )
+        total = torch.sum(mask)
+        sample_size = (
+            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+        )
+
+        logging_output = {
+            "loss": utils.item(loss.data),  # * sample['ntokens'],
+            "ntokens": sample["ntokens"],
+            "nsentences": sample["target"].size(0),
+            "sample_size": sample_size,
+            "correct": utils.item(correct.data),
+            "total": utils.item(total.data),
+            "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
+        }
+
+        return sample_size, logging_output
+
+    def forward(self, model, sample, reduction="sum", log_probs=True):
+        """Computes the cross entropy with accuracy metric for the given sample.
+
+        This is similar to CrossEntropyCriterion in fairseq, but also
+        computes accuracy metrics as part of logging.
+
+        Args:
+            logprobs (Torch.tensor) of shape N, T, D i.e.
+                batchsize, timesteps, dimensions
+            targets (Torch.tensor) of shape N, T i.e. batchsize, timesteps
+
+        Returns:
+            tuple: With three elements:
+                1) the loss
+                2) the sample size, which is used as the denominator for the gradient
+                3) logging outputs to display while training
+
+        TODO:
+            * Currently this Criterion will only work with LSTMEncoderModels or
+              FairseqModels which have a decoder, or models which return a torch
+              tensor as net_output.
+              We need to make a change to support all FairseqEncoder models.
+        """
+        net_output = model(**sample["net_input"])
+        target = model.get_targets(sample, net_output)
+        lprobs, loss = self.compute_loss(
+            model, net_output, target, reduction, log_probs
+        )
+        sample_size, logging_output = self.get_logging_output(
+            sample, target, lprobs, loss
+        )
+        return loss, sample_size, logging_output
+
+    @staticmethod
+    def aggregate_logging_outputs(logging_outputs):
+        """Aggregate logging outputs from data parallel training."""
+        correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
+        total_sum = sum(log.get("total", 0) for log in logging_outputs)
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        nframes = sum(log.get("nframes", 0) for log in logging_outputs)
+        agg_output = {
+            "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
+            # if args.sentence_avg, then sample_size is nsentences and loss
+            # is the per-sentence loss; else sample_size is ntokens and loss
+            # becomes the per-output-token loss
+            "ntokens": ntokens,
+            "nsentences": nsentences,
+            "nframes": nframes,
+            "sample_size": sample_size,
+            "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
+            "correct": correct_sum,
+            "total": total_sum,
+            # total is the number of validated tokens
+        }
+        if sample_size != ntokens:
+            agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
+        # loss: per output token loss
+        # nll_loss: per sentence loss
+        return agg_output
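The `correct`/`total` bookkeeping in `get_logging_output` above computes token accuracy over non-padding positions only. A tensor-free sketch of the same masking logic, with plain lists standing in for `lprobs.argmax(1)` and `target` (the function name is ours):

```python
def masked_accuracy(predictions, targets, padding_idx=1):
    """Percentage of predicted tokens matching the target, ignoring padding.

    Mirrors the mask/correct/total computation in get_logging_output:
    positions where the target equals padding_idx count toward neither
    'correct' nor 'total'.
    """
    correct = total = 0
    for p, t in zip(predictions, targets):
        if t == padding_idx:
            continue  # excluded by the padding mask
        total += 1
        correct += int(p == t)
    return 100.0 * correct / total if total > 0 else 0.0

# two of three real tokens match; the final padding position is ignored
print(masked_accuracy([5, 7, 2, 1], [5, 7, 9, 1]))
```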
fairseq-0.10.2/examples/speech_recognition/datasets/asr_prep_json.py ADDED
@@ -0,0 +1,125 @@
+#!/usr/bin/env python3
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import argparse
+import concurrent.futures
+import json
+import multiprocessing
+import os
+from collections import namedtuple
+from itertools import chain
+
+import sentencepiece as spm
+from fairseq.data import Dictionary
+
+
+MILLISECONDS_TO_SECONDS = 0.001
+
+
+def process_sample(aud_path, label, utt_id, sp, tgt_dict):
+    import torchaudio
+
+    input = {}
+    output = {}
+    si, ei = torchaudio.info(aud_path)
+    input["length_ms"] = int(
+        si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
+    )
+    input["path"] = aud_path
+
+    token = " ".join(sp.EncodeAsPieces(label))
+    ids = tgt_dict.encode_line(token, append_eos=False)
+    output["text"] = label
+    output["token"] = token
+    output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
+    return {utt_id: {"input": input, "output": output}}
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--audio-dirs",
+        nargs="+",
+        default=["-"],
+        required=True,
+        help="input directories with audio files",
+    )
+    parser.add_argument(
+        "--labels",
+        required=True,
+        help="aggregated input labels with format <ID LABEL> per line",
+        type=argparse.FileType("r", encoding="UTF-8"),
+    )
+    parser.add_argument(
+        "--spm-model",
+        required=True,
+        help="sentencepiece model to use for encoding",
+        type=argparse.FileType("r", encoding="UTF-8"),
+    )
+    parser.add_argument(
+        "--dictionary",
+        required=True,
+        help="file to load fairseq dictionary from",
+        type=argparse.FileType("r", encoding="UTF-8"),
+    )
+    parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
+    parser.add_argument(
+        "--output",
+        required=True,
+        type=argparse.FileType("w"),
+        help="path to save json output",
+    )
+    args = parser.parse_args()
+
+    sp = spm.SentencePieceProcessor()
+    sp.Load(args.spm_model.name)
+
+    tgt_dict = Dictionary.load(args.dictionary)
+
+    labels = {}
+    for line in args.labels:
+        (utt_id, label) = line.split(" ", 1)
+        labels[utt_id] = label
+    if len(labels) == 0:
+        raise Exception("No labels found in ", args.labels.name)
+
+    Sample = namedtuple("Sample", "aud_path utt_id")
+    samples = []
+    for path, _, files in chain.from_iterable(
+        os.walk(path) for path in args.audio_dirs
+    ):
+        for f in files:
+            if f.endswith(args.audio_format):
+                if len(os.path.splitext(f)) != 2:
+                    raise Exception("Expect <utt_id.extension> file name. Got: ", f)
+                utt_id = os.path.splitext(f)[0]
+                if utt_id not in labels:
+                    continue
+                samples.append(Sample(os.path.join(path, f), utt_id))
+
+    utts = {}
+    num_cpu = multiprocessing.cpu_count()
+    with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
+        future_to_sample = {
+            executor.submit(
+                process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
+            ): s
+            for s in samples
+        }
+        for future in concurrent.futures.as_completed(future_to_sample):
+            try:
+                data = future.result()
+            except Exception as exc:
+                print("generated an exception: ", exc)
+            else:
+                utts.update(data)
+    json.dump({"utts": utts}, args.output, indent=4)
+
+
+if __name__ == "__main__":
+    main()
fairseq-0.10.2/examples/speech_recognition/datasets/prepare-librispeech.sh
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Prepare librispeech dataset
|
| 8 |
+
|
| 9 |
+
base_url=www.openslr.org/resources/12
|
| 10 |
+
train_dir=train_960
|
| 11 |
+
|
| 12 |
+
if [ "$#" -ne 2 ]; then
|
| 13 |
+
echo "Usage: $0 <download_dir> <out_dir>"
|
| 14 |
+
echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
|
| 15 |
+
exit 1
|
| 16 |
+
fi
|
| 17 |
+
|
| 18 |
+
download_dir=${1%/}
|
| 19 |
+
out_dir=${2%/}
|
| 20 |
+
|
| 21 |
+
fairseq_root=~/fairseq-py/
|
| 22 |
+
mkdir -p ${out_dir}
|
| 23 |
+
cd ${out_dir} || exit
|
| 24 |
+
|
| 25 |
+
nbpe=5000
|
| 26 |
+
bpemode=unigram
|
| 27 |
+
|
| 28 |
+
if [ ! -d "$fairseq_root" ]; then
|
| 29 |
+
echo "$0: Please set correct fairseq_root"
|
| 30 |
+
exit 1
|
| 31 |
+
fi
|
| 32 |
+
|
| 33 |
+
echo "Data Download"
|
| 34 |
+
for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
|
| 35 |
+
url=$base_url/$part.tar.gz
|
| 36 |
+
if ! wget -P $download_dir $url; then
|
| 37 |
+
echo "$0: wget failed for $url"
|
| 38 |
+
exit 1
|
| 39 |
+
fi
|
| 40 |
+
if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
|
| 41 |
+
echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
|
| 42 |
+
exit 1
|
| 43 |
+
fi
|
| 44 |
+
done
|
| 45 |
+
|
| 46 |
+
echo "Merge all train packs into one"
|
| 47 |
+
mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
|
| 48 |
+
for part in train-clean-100 train-clean-360 train-other-500; do
|
| 49 |
+
mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
|
| 50 |
+
done
|
echo "Merge train text"
find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text

# Use combined dev-clean and dev-other as validation set
find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text


dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
echo "dictionary: ${dict}"
echo "Dictionary preparation"
mkdir -p data/lang_char/
echo "<unk> 3" > ${dict}
echo "</s> 2" >> ${dict}
echo "<pad> 1" >> ${dict}
cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
wc -l ${dict}

echo "Prepare train and test jsons"
for part in train_960 test-other test-clean; do
    python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
done
# fairseq expects to find train.json and valid.json during training
mv train_960.json train.json

echo "Prepare valid json"
python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json

cp ${fairseq_dict} ./dict.txt
cp ${bpemodel}.model ./spm.model
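The `fairseq_dict` built above pairs each BPE piece with its corpus frequency. As an illustrative sketch only (not part of the recipe), the effect of the `tr | sort | uniq -c | awk` pipeline can be reproduced in pure Python; the function name `build_fairseq_dict` is our own:

```python
from collections import Counter


def build_fairseq_dict(encoded_lines):
    """Count piece frequencies over whitespace-tokenized lines and emit
    one '<piece> <count>' entry per unique piece, in sorted piece order,
    mirroring the shell pipeline above."""
    counts = Counter()
    for line in encoded_lines:
        counts.update(line.split())
    return ["{} {}".format(piece, n) for piece, n in sorted(counts.items())]


print(build_fairseq_dict(["he llo", "he wo"]))  # → ['he 2', 'llo 1', 'wo 1']
```

Note that `sort | uniq -c` also orders pieces lexically, which `sorted()` matches here.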
fairseq-0.10.2/examples/speech_recognition/infer.py
ADDED
@@ -0,0 +1,464 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Run inference for pre-processed data with a trained model.
"""

import logging
import math
import os
import sys

import editdistance
import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data.data_utils import post_process
from fairseq.logging.meters import StopwatchMeter, TimeMeter


logging.basicConfig()
logging.root.setLevel(logging.INFO)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def add_asr_eval_argument(parser):
    parser.add_argument("--kspmodel", default=None, help="sentence piece model")
    parser.add_argument(
        "--wfstlm", default=None, help="wfstlm on dictionary output units"
    )
    parser.add_argument(
        "--rnnt_decoding_type",
        default="greedy",
        help="wfstlm on dictionary\
output units",
    )
    try:
        parser.add_argument(
            "--lm-weight",
            "--lm_weight",
            type=float,
            default=0.2,
            help="weight for lm while interpolating with neural score",
        )
    except:
        pass
    parser.add_argument(
        "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
    )
    parser.add_argument(
        "--w2l-decoder",
        choices=["viterbi", "kenlm", "fairseqlm"],
        help="use a w2l decoder",
    )
    parser.add_argument("--lexicon", help="lexicon for w2l decoder")
    parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
    parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
    parser.add_argument("--beam-threshold", type=float, default=25.0)
    parser.add_argument("--beam-size-token", type=float, default=100)
    parser.add_argument("--word-score", type=float, default=1.0)
    parser.add_argument("--unk-weight", type=float, default=-math.inf)
    parser.add_argument("--sil-weight", type=float, default=0.0)
    parser.add_argument(
        "--dump-emissions",
        type=str,
        default=None,
        help="if present, dumps emissions into this file and exits",
    )
    parser.add_argument(
        "--dump-features",
        type=str,
        default=None,
        help="if present, dumps features into this file and exits",
    )
    parser.add_argument(
        "--load-emissions",
        type=str,
        default=None,
        help="if present, loads emissions from this file",
    )
    return parser


def check_args(args):
    # assert args.path is not None, "--path required for generation!"
    # assert args.results_path is not None, "--results_path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"


def get_dataset_itr(args, task, models):
    return task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_sentences=args.batch_size,
        max_positions=(sys.maxsize, sys.maxsize),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=args.required_batch_size_multiple,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
        data_buffer_size=args.data_buffer_size,
    ).next_epoch_itr(shuffle=False)


def process_predictions(
    args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
):
    for hypo in hypos[: min(len(hypos), args.nbest)]:
        hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())

        if "words" in hypo:
            hyp_words = " ".join(hypo["words"])
        else:
            hyp_words = post_process(hyp_pieces, args.remove_bpe)

        if res_files is not None:
            print(
                "{} ({}-{})".format(hyp_pieces, speaker, id),
                file=res_files["hypo.units"],
            )
            print(
                "{} ({}-{})".format(hyp_words, speaker, id),
                file=res_files["hypo.words"],
            )

        tgt_pieces = tgt_dict.string(target_tokens)
        tgt_words = post_process(tgt_pieces, args.remove_bpe)

        if res_files is not None:
            print(
                "{} ({}-{})".format(tgt_pieces, speaker, id),
                file=res_files["ref.units"],
            )
            print(
                "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
            )
        # only score top hypothesis
        if not args.quiet:
            logger.debug("HYPO:" + hyp_words)
            logger.debug("TARGET:" + tgt_words)
            logger.debug("___________________")

        hyp_words = hyp_words.split()
        tgt_words = tgt_words.split()
        return editdistance.eval(hyp_words, tgt_words), len(tgt_words)


def prepare_result_files(args):
    def get_res_file(file_prefix):
        if args.num_shards > 1:
            file_prefix = f"{args.shard_id}_{file_prefix}"
        path = os.path.join(
            args.results_path,
            "{}-{}-{}.txt".format(
                file_prefix, os.path.basename(args.path), args.gen_subset
            ),
        )
        return open(path, "w", buffering=1)

    if not args.results_path:
        return None

    return {
        "hypo.words": get_res_file("hypo.word"),
        "hypo.units": get_res_file("hypo.units"),
        "ref.words": get_res_file("ref.word"),
        "ref.units": get_res_file("ref.units"),
    }


def load_models_and_criterions(
    filenames, data_path, arg_overrides=None, task=None, model_state=None
):
    models = []
    criterions = []

    if arg_overrides is None:
        arg_overrides = {}

    arg_overrides["wer_args"] = None
    arg_overrides["data"] = data_path

    if filenames is None:
        assert model_state is not None
        filenames = [0]
    else:
        filenames = filenames.split(":")

    for filename in filenames:
        if model_state is None:
            if not os.path.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
        else:
            state = model_state

        args = state["args"]
        if task is None:
            task = tasks.setup_task(args)
        model = task.build_model(args)
        model.load_state_dict(state["model"], strict=True)
        models.append(model)

        criterion = task.build_criterion(args)
        if "criterion" in state:
            criterion.load_state_dict(state["criterion"], strict=True)
        criterions.append(criterion)
    return models, criterions, args


def optimize_models(args, use_cuda, models):
    """Optimize ensemble for generation"""
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()


class ExistingEmissionsDecoder(object):
    def __init__(self, decoder, emissions):
        self.decoder = decoder
        self.emissions = emissions

    def generate(self, models, sample, **unused):
        ids = sample["id"].cpu().numpy()
        try:
            emissions = np.stack(self.emissions[ids])
        except:
            print([x.shape for x in self.emissions[ids]])
            raise Exception("invalid sizes")
        emissions = torch.from_numpy(emissions)
        return self.decoder.decode(emissions)


def main(args, task=None, model_state=None):
    check_args(args)

    if args.max_tokens is None and args.batch_size is None:
        args.max_tokens = 4000000
    logger.info(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    if task is None:
        # Load dataset splits
        task = tasks.setup_task(args)
        task.load_dataset(args.gen_subset)

        logger.info(
            "| {} {} {} examples".format(
                args.data, args.gen_subset, len(task.dataset(args.gen_subset))
            )
        )

    # Set dictionary
    tgt_dict = task.target_dictionary

    logger.info("| decoding with criterion {}".format(args.criterion))

    # Load ensemble

    if args.load_emissions:
        models, criterions = [], []
    else:
        logger.info("| loading model(s) from {}".format(args.path))
        models, criterions, _ = load_models_and_criterions(
            args.path,
            data_path=args.data,
            arg_overrides=eval(args.model_overrides),  # noqa
            task=task,
            model_state=model_state,
        )
        optimize_models(args, use_cuda, models)

    # hack to pass transitions to W2lDecoder
    if args.criterion == "asg_loss":
        trans = criterions[0].asg.trans.data
        args.asg_transitions = torch.flatten(trans).tolist()

    # Load dataset (possibly sharded)
    itr = get_dataset_itr(args, task, models)

    # Initialize generator
    gen_timer = StopwatchMeter()

    def build_generator(args):
        w2l_decoder = getattr(args, "w2l_decoder", None)
        if w2l_decoder == "viterbi":
            from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder

            return W2lViterbiDecoder(args, task.target_dictionary)
        elif w2l_decoder == "kenlm":
            from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder

            return W2lKenLMDecoder(args, task.target_dictionary)
        elif w2l_decoder == "fairseqlm":
            from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder

            return W2lFairseqLMDecoder(args, task.target_dictionary)
        else:
            print(
                "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
            )

    # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
    generator = build_generator(args)

    if args.load_emissions:
        generator = ExistingEmissionsDecoder(
            generator, np.load(args.load_emissions, allow_pickle=True)
        )
        logger.info("loaded emissions from " + args.load_emissions)

    num_sentences = 0

    if args.results_path is not None and not os.path.exists(args.results_path):
        os.makedirs(args.results_path)

    max_source_pos = (
        utils.resolve_max_positions(
            task.max_positions(), *[model.max_positions() for model in models]
        ),
    )

    if max_source_pos is not None:
        max_source_pos = max_source_pos[0]
        if max_source_pos is not None:
            max_source_pos = max_source_pos[0] - 1

    if args.dump_emissions:
        emissions = {}
    if args.dump_features:
        features = {}
        models[0].bert.proj = None
    else:
        res_files = prepare_result_files(args)
    errs_t = 0
    lengths_t = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if "net_input" not in sample:
                continue

            prefix_tokens = None
            if args.prefix_size > 0:
                prefix_tokens = sample["target"][:, : args.prefix_size]

            gen_timer.start()
            if args.dump_emissions:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
                    emm = emm.transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        emissions[id.item()] = emm[i]
                    continue
            elif args.dump_features:
                with torch.no_grad():
                    encoder_out = models[0](**sample["net_input"])
                    feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
                    for i, id in enumerate(sample["id"]):
                        padding = (
                            encoder_out["encoder_padding_mask"][i].cpu().numpy()
                            if encoder_out["encoder_padding_mask"] is not None
                            else None
                        )
                        features[id.item()] = (feat[i], padding)
                    continue
            hypos = task.inference_step(generator, models, sample, prefix_tokens)
            num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
            gen_timer.stop(num_generated_tokens)

            for i, sample_id in enumerate(sample["id"].tolist()):
                speaker = None
                # id = task.dataset(args.gen_subset).ids[int(sample_id)]
                id = sample_id
                toks = (
                    sample["target"][i, :]
                    if "target_label" not in sample
                    else sample["target_label"][i, :]
                )
                target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
                # Process top predictions
                errs, length = process_predictions(
                    args,
                    hypos[i],
                    None,
                    tgt_dict,
                    target_tokens,
                    res_files,
                    speaker,
                    id,
                )
                errs_t += errs
                lengths_t += length

            wps_meter.update(num_generated_tokens)
            t.log({"wps": round(wps_meter.avg)})
            num_sentences += (
                sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
            )

    wer = None
    if args.dump_emissions:
        emm_arr = []
        for i in range(len(emissions)):
            emm_arr.append(emissions[i])
        np.save(args.dump_emissions, emm_arr)
        logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
    elif args.dump_features:
        feat_arr = []
        for i in range(len(features)):
            feat_arr.append(features[i])
        np.save(args.dump_features, feat_arr)
        logger.info(f"saved {len(features)} features to {args.dump_features}")
    else:
        if lengths_t > 0:
            wer = errs_t * 100.0 / lengths_t
            logger.info(f"WER: {wer}")

        logger.info(
            "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
            "sentences/s, {:.2f} tokens/s)".format(
                num_sentences,
                gen_timer.n,
                gen_timer.sum,
                num_sentences / gen_timer.sum,
                1.0 / gen_timer.avg,
            )
        )
        logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
    return task, wer


def make_parser():
    parser = options.get_generation_parser()
    parser = add_asr_eval_argument(parser)
    return parser


def cli_main():
    parser = make_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == "__main__":
    cli_main()
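The WER that infer.py reports is word-level edit distance accumulated over the corpus, divided by the total reference length (`errs_t * 100.0 / lengths_t`). A minimal self-contained sketch of that computation, using a plain Levenshtein implementation in place of the `editdistance` package (the helper names are ours):

```python
def edit_distance(hyp, ref):
    """Word-level Levenshtein distance via one-row dynamic programming."""
    d = list(range(len(ref) + 1))
    for i, h in enumerate(hyp, 1):
        prev, d[0] = d[0], i
        for j, r in enumerate(ref, 1):
            # prev holds the diagonal cell d[i-1][j-1]
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1, prev + (h != r))
    return d[len(ref)]


def corpus_wer(pairs):
    """pairs: iterable of (hypothesis, reference) strings; returns WER in percent."""
    errs = lengths = 0
    for hyp, ref in pairs:
        hyp_words, ref_words = hyp.split(), ref.split()
        errs += edit_distance(hyp_words, ref_words)
        lengths += len(ref_words)
    return errs * 100.0 / lengths


# one substitution over 5 reference words
print(corpus_wer([("a b c", "a b d"), ("x y", "x y")]))  # → 20.0
```

Note that, like infer.py, this normalizes by reference length, so WER can exceed 100% when hypotheses contain many insertions.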
fairseq-0.10.2/examples/speech_recognition/utils/wer_utils.py
ADDED
@@ -0,0 +1,381 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import absolute_import, division, print_function, unicode_literals

import re
from collections import deque
from enum import Enum

import numpy as np


"""
Utility modules for computation of Word Error Rate,
Alignments, as well as more granular metrics like
deletion, insertion and substitutions.
"""


class Code(Enum):
    match = 1
    substitution = 2
    insertion = 3
    deletion = 4


class Token(object):
    def __init__(self, lbl="", st=np.nan, en=np.nan):
        if np.isnan(st):
            self.label, self.start, self.end = "", 0.0, 0.0
        else:
            self.label, self.start, self.end = lbl, st, en


class AlignmentResult(object):
    def __init__(self, refs, hyps, codes, score):
        self.refs = refs  # std::deque<int>
        self.hyps = hyps  # std::deque<int>
        self.codes = codes  # std::deque<Code>
        self.score = score  # float


def coordinate_to_offset(row, col, ncols):
    return int(row * ncols + col)


def offset_to_row(offset, ncols):
    return int(offset / ncols)


def offset_to_col(offset, ncols):
    return int(offset % ncols)


def trimWhitespace(str):
    return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str)))


def str2toks(str):
    pieces = trimWhitespace(str).split(" ")
    toks = []
    for p in pieces:
        toks.append(Token(p, 0.0, 0.0))
    return toks


class EditDistance(object):
    def __init__(self, time_mediated):
        self.time_mediated_ = time_mediated
        self.scores_ = np.nan  # Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>
        self.backtraces_ = (
            np.nan
        )  # Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic> backtraces_;
        self.confusion_pairs_ = {}

    def cost(self, ref, hyp, code):
        if self.time_mediated_:
            if code == Code.match:
                return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
            elif code == Code.insertion:
                return hyp.end - hyp.start
            elif code == Code.deletion:
                return ref.end - ref.start
            else:  # substitution
                return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
        else:
            if code == Code.match:
                return 0
            elif code == Code.insertion or code == Code.deletion:
                return 3
            else:  # substitution
                return 4

    def get_result(self, refs, hyps):
        res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)

        num_rows, num_cols = self.scores_.shape
        res.score = self.scores_[num_rows - 1, num_cols - 1]

        curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)

        while curr_offset != 0:
            curr_row = offset_to_row(curr_offset, num_cols)
            curr_col = offset_to_col(curr_offset, num_cols)

            prev_offset = self.backtraces_[curr_row, curr_col]

            prev_row = offset_to_row(prev_offset, num_cols)
            prev_col = offset_to_col(prev_offset, num_cols)

            res.refs.appendleft(curr_row - 1)  # Note: this was .push_front() in C++
            res.hyps.appendleft(curr_col - 1)
            if curr_row - 1 == prev_row and curr_col == prev_col:
                res.codes.appendleft(Code.deletion)
            elif curr_row == prev_row and curr_col - 1 == prev_col:
                res.codes.appendleft(Code.insertion)
            else:
                # assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
                ref_str = refs[res.refs[0]].label
                hyp_str = hyps[res.hyps[0]].label

                if ref_str == hyp_str:
                    res.codes.appendleft(Code.match)
                else:
                    res.codes.appendleft(Code.substitution)

                    confusion_pair = "%s -> %s" % (ref_str, hyp_str)
                    if confusion_pair not in self.confusion_pairs_:
                        self.confusion_pairs_[confusion_pair] = 1
                    else:
                        self.confusion_pairs_[confusion_pair] += 1

            curr_offset = prev_offset

        return res

    def align(self, refs, hyps):
        if len(refs) == 0 and len(hyps) == 0:
            return np.nan

        # NOTE: we're not resetting the values in these matrices because every value
        # will be overridden in the loop below. If this assumption doesn't hold,
        # be sure to set all entries in self.scores_ and self.backtraces_ to 0.
        self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
        self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))

        num_rows, num_cols = self.scores_.shape

        for i in range(num_rows):
            for j in range(num_cols):
                if i == 0 and j == 0:
                    self.scores_[i, j] = 0.0
                    self.backtraces_[i, j] = 0
                    continue

                if i == 0:
                    self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
                        None, hyps[j - 1], Code.insertion
                    )
                    self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
                    continue

                if j == 0:
                    self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
                        refs[i - 1], None, Code.deletion
                    )
                    self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
                    continue

                # Below here both i and j are greater than 0
                ref = refs[i - 1]
                hyp = hyps[j - 1]
                best_score = self.scores_[i - 1, j - 1] + (
                    self.cost(ref, hyp, Code.match)
                    if (ref.label == hyp.label)
                    else self.cost(ref, hyp, Code.substitution)
                )

                prev_row = i - 1
                prev_col = j - 1
                ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
                if ins < best_score:
                    best_score = ins
                    prev_row = i
                    prev_col = j - 1

                delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
                if delt < best_score:
                    best_score = delt
                    prev_row = i - 1
                    prev_col = j

                self.scores_[i, j] = best_score
                self.backtraces_[i, j] = coordinate_to_offset(
                    prev_row, prev_col, num_cols
                )

        return self.get_result(refs, hyps)


class WERTransformer(object):
    def __init__(self, hyp_str, ref_str, verbose=True):
        self.ed_ = EditDistance(False)
        self.id2oracle_errs_ = {}
        self.utts_ = 0
        self.words_ = 0
        self.insertions_ = 0
        self.deletions_ = 0
        self.substitutions_ = 0

        self.process(["dummy_str", hyp_str, ref_str])

        if verbose:
            print("'%s' vs '%s'" % (hyp_str, ref_str))
            self.report_result()

    def process(self, input):  # std::vector<std::string>&& input
        if len(input) < 3:
            print(
                "Input must be of the form <id> ... <hypo> <ref> , got ",
                len(input),
                " inputs:",
            )
            return None

        # Align
        # std::vector<Token> hyps;
        # std::vector<Token> refs;

        hyps = str2toks(input[-2])
        refs = str2toks(input[-1])

        alignment = self.ed_.align(refs, hyps)
        if alignment is None:
            print("Alignment is null")
            return np.nan

        # Tally errors
        ins = 0
        dels = 0
        subs = 0
        for code in alignment.codes:
            if code == Code.substitution:
                subs += 1
            elif code == Code.insertion:
                ins += 1
            elif code == Code.deletion:
                dels += 1

        # Output
        row = input
        row.append(str(len(refs)))
        row.append(str(ins))
        row.append(str(dels))
        row.append(str(subs))
        # print(row)

        # Accumulate
        kIdIndex = 0
        kNBestSep = "/"

        pieces = input[kIdIndex].split(kNBestSep)

        if len(pieces) == 0:
            print(
                "Error splitting ",
                input[kIdIndex],
                " on '",
                kNBestSep,
                "', got empty list",
            )
            return np.nan

        id = pieces[0]
        if id not in self.id2oracle_errs_:
            self.utts_ += 1
            self.words_ += len(refs)
            self.insertions_ += ins
            self.deletions_ += dels
            self.substitutions_ += subs
            self.id2oracle_errs_[id] = [ins, dels, subs]
        else:
            curr_err = ins + dels + subs
            prev_err = np.sum(self.id2oracle_errs_[id])
            if curr_err < prev_err:
                self.id2oracle_errs_[id] = [ins, dels, subs]

        return 0

    def report_result(self):
        # print("---------- Summary ---------------")
        if self.words_ == 0:
            print("No words counted")
            return

        # 1-best
        best_wer = (
            100.0
            * (self.insertions_ + self.deletions_ + self.substitutions_)
|
| 304 |
+
/ self.words_
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
print(
|
| 308 |
+
"\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
|
| 309 |
+
"%0.2f%% dels, %0.2f%% subs)"
|
| 310 |
+
% (
|
| 311 |
+
best_wer,
|
| 312 |
+
self.utts_,
|
| 313 |
+
self.words_,
|
| 314 |
+
100.0 * self.insertions_ / self.words_,
|
| 315 |
+
100.0 * self.deletions_ / self.words_,
|
| 316 |
+
100.0 * self.substitutions_ / self.words_,
|
| 317 |
+
)
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
def wer(self):
|
| 321 |
+
if self.words_ == 0:
|
| 322 |
+
wer = np.nan
|
| 323 |
+
else:
|
| 324 |
+
wer = (
|
| 325 |
+
100.0
|
| 326 |
+
* (self.insertions_ + self.deletions_ + self.substitutions_)
|
| 327 |
+
/ self.words_
|
| 328 |
+
)
|
| 329 |
+
return wer
|
| 330 |
+
|
| 331 |
+
def stats(self):
|
| 332 |
+
if self.words_ == 0:
|
| 333 |
+
stats = {}
|
| 334 |
+
else:
|
| 335 |
+
wer = (
|
| 336 |
+
100.0
|
| 337 |
+
* (self.insertions_ + self.deletions_ + self.substitutions_)
|
| 338 |
+
/ self.words_
|
| 339 |
+
)
|
| 340 |
+
stats = dict(
|
| 341 |
+
{
|
| 342 |
+
"wer": wer,
|
| 343 |
+
"utts": self.utts_,
|
| 344 |
+
"numwords": self.words_,
|
| 345 |
+
"ins": self.insertions_,
|
| 346 |
+
"dels": self.deletions_,
|
| 347 |
+
"subs": self.substitutions_,
|
| 348 |
+
"confusion_pairs": self.ed_.confusion_pairs_,
|
| 349 |
+
}
|
| 350 |
+
)
|
| 351 |
+
return stats
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def calc_wer(hyp_str, ref_str):
|
| 355 |
+
t = WERTransformer(hyp_str, ref_str, verbose=0)
|
| 356 |
+
return t.wer()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def calc_wer_stats(hyp_str, ref_str):
|
| 360 |
+
t = WERTransformer(hyp_str, ref_str, verbose=0)
|
| 361 |
+
return t.stats()
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def get_wer_alignment_codes(hyp_str, ref_str):
|
| 365 |
+
"""
|
| 366 |
+
INPUT: hypothesis string, reference string
|
| 367 |
+
OUTPUT: List of alignment codes (intermediate results from WER computation)
|
| 368 |
+
"""
|
| 369 |
+
t = WERTransformer(hyp_str, ref_str, verbose=0)
|
| 370 |
+
return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def merge_counts(x, y):
|
| 374 |
+
# Merge two hashes which have 'counts' as their values
|
| 375 |
+
# This can be used for example to merge confusion pair counts
|
| 376 |
+
# conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
|
| 377 |
+
for k, v in y.items():
|
| 378 |
+
if k not in x:
|
| 379 |
+
x[k] = 0
|
| 380 |
+
x[k] += v
|
| 381 |
+
return x
|
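The dynamic program above fills `scores_[i, j]` with the cheapest way to align the first `i` reference tokens against the first `j` hypothesis tokens, and `wer()` then reports `(ins + dels + subs) / ref_words`. As a standalone sketch of that same recurrence with unit costs (the real `EditDistance` class supports configurable and time-mediated costs; `word_error_rate` here is a hypothetical helper, not part of fairseq):

```python
# Minimal sketch of the WER recurrence used above: Levenshtein alignment
# over whitespace tokens, then 100 * edit_distance / reference_length.
def word_error_rate(hyp_str, ref_str):
    hyps = hyp_str.split()
    refs = ref_str.split()
    n, m = len(refs), len(hyps)
    # scores[i][j] = min edit cost aligning refs[:i] with hyps[:j]
    scores = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        scores[i][0] = i  # delete all reference words
    for j in range(1, m + 1):
        scores[0][j] = j  # insert all hypothesis words
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            sub = scores[i - 1][j - 1] + (refs[i - 1] != hyps[j - 1])
            ins = scores[i][j - 1] + 1
            delt = scores[i - 1][j] + 1
            scores[i][j] = min(sub, ins, delt)
    return 100.0 * scores[n][m] / n if n else float("nan")

print(word_error_rate("a b c d", "a x c d"))  # one substitution over 4 words: 25.0
```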
fairseq-0.10.2/examples/speech_recognition/w2l_decoder.py
ADDED
@@ -0,0 +1,435 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Wav2letter decoders.
"""

import gc
import itertools as it
import os.path as osp
import warnings
from collections import deque, namedtuple

import numpy as np
import torch
from examples.speech_recognition.data.replabels import unpack_replabels
from fairseq import tasks
from fairseq.utils import apply_to_sample


try:
    from wav2letter.common import create_word_dict, load_words
    from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
    from wav2letter.decoder import (
        CriterionType,
        DecoderOptions,
        KenLM,
        LM,
        LMState,
        SmearingMode,
        Trie,
        LexiconDecoder,
        LexiconFreeDecoder,
    )
except:
    warnings.warn(
        "wav2letter python bindings are required to use this functionality. Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"
    )
    LM = object
    LMState = object


class W2lDecoder(object):
    def __init__(self, args, tgt_dict):
        self.tgt_dict = tgt_dict
        self.vocab_size = len(tgt_dict)
        self.nbest = args.nbest

        # criterion-specific init
        if args.criterion == "ctc":
            self.criterion_type = CriterionType.CTC
            self.blank = (
                tgt_dict.index("<ctc_blank>")
                if "<ctc_blank>" in tgt_dict.indices
                else tgt_dict.bos()
            )
            self.asg_transitions = None
        elif args.criterion == "asg_loss":
            self.criterion_type = CriterionType.ASG
            self.blank = -1
            self.asg_transitions = args.asg_transitions
            self.max_replabel = args.max_replabel
            assert len(self.asg_transitions) == self.vocab_size ** 2
        else:
            raise RuntimeError(f"unknown criterion: {args.criterion}")

    def generate(self, models, sample, **unused):
        """Generate a batch of inferences."""
        # model.forward normally channels prev_output_tokens into the decoder
        # separately, but SequenceGenerator directly calls model.encoder
        encoder_input = {
            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
        }
        emissions = self.get_emissions(models, encoder_input)
        return self.decode(emissions)

    def get_emissions(self, models, encoder_input):
        """Run encoder and normalize emissions"""
        # encoder_out = models[0].encoder(**encoder_input)
        encoder_out = models[0](**encoder_input)
        if self.criterion_type == CriterionType.CTC:
            emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
        elif self.criterion_type == CriterionType.ASG:
            emissions = encoder_out["encoder_out"]
        return emissions.transpose(0, 1).float().cpu().contiguous()

    def get_tokens(self, idxs):
        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
        idxs = (g[0] for g in it.groupby(idxs))
        if self.criterion_type == CriterionType.CTC:
            idxs = filter(lambda x: x != self.blank, idxs)
        elif self.criterion_type == CriterionType.ASG:
            idxs = filter(lambda x: x >= 0, idxs)
            idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
        return torch.LongTensor(list(idxs))


class W2lViterbiDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        if self.asg_transitions is None:
            transitions = torch.FloatTensor(N, N).zero_()
        else:
            transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
        viterbi_path = torch.IntTensor(B, T)
        workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
        CpuViterbiPath.compute(
            B,
            T,
            N,
            get_data_ptr_as_bytes(emissions),
            get_data_ptr_as_bytes(transitions),
            get_data_ptr_as_bytes(viterbi_path),
            get_data_ptr_as_bytes(workspace),
        )
        return [
            [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
            for b in range(B)
        ]


class W2lKenLMDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        self.silence = (
            tgt_dict.index("<ctc_blank>")
            if "<ctc_blank>" in tgt_dict.indices
            else tgt_dict.bos()
        )
        self.lexicon = load_words(args.lexicon)
        self.word_dict = create_word_dict(self.lexicon)
        self.unk_word = self.word_dict.get_index("<unk>")

        self.lm = KenLM(args.kenlm_model, self.word_dict)
        self.trie = Trie(self.vocab_size, self.silence)

        start_state = self.lm.start(False)
        for i, (word, spellings) in enumerate(self.lexicon.items()):
            word_idx = self.word_dict.get_index(word)
            _, score = self.lm.score(start_state, word_idx)
            for spelling in spellings:
                spelling_idxs = [tgt_dict.index(token) for token in spelling]
                assert (
                    tgt_dict.unk() not in spelling_idxs
                ), f"{spelling} {spelling_idxs}"
                self.trie.insert(spelling_idxs, word_idx, score)
        self.trie.smear(SmearingMode.MAX)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,
            False,
            self.criterion_type,
        )

        if self.asg_transitions is None:
            N = 768
            # self.asg_transitions = torch.FloatTensor(N, N).zero_()
            self.asg_transitions = []

        self.decoder = LexiconDecoder(
            self.decoder_opts,
            self.trie,
            self.lm,
            self.silence,
            self.blank,
            self.unk_word,
            self.asg_transitions,
            False,
        )

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []
        for b in range(B):
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append(
                [
                    {
                        "tokens": self.get_tokens(result.tokens),
                        "score": result.score,
                        "words": [
                            self.word_dict.get_entry(x) for x in result.words if x >= 0
                        ],
                    }
                    for result in nbest_results
                ]
            )
        return hypos


FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])


class FairseqLM(LM):
    def __init__(self, dictionary, model):
        LM.__init__(self)
        self.dictionary = dictionary
        self.model = model
        self.unk = self.dictionary.unk()

        self.save_incremental = False  # this currently does not work properly
        self.max_cache = 20_000

        model.cuda()
        model.eval()
        model.make_generation_fast_()

        self.states = {}
        self.stateq = deque()

    def start(self, start_with_nothing):
        state = LMState()
        prefix = torch.LongTensor([[self.dictionary.eos()]])
        incremental_state = {} if self.save_incremental else None
        with torch.no_grad():
            res = self.model(prefix.cuda(), incremental_state=incremental_state)
            probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)

        if incremental_state is not None:
            incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
        self.states[state] = FairseqLMState(
            prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
        )
        self.stateq.append(state)

        return state

    def score(self, state: LMState, token_index: int, no_cache: bool = False):
        """
        Evaluate language model based on the current lm state and new word
        Parameters:
        -----------
        state: current lm state
        token_index: index of the word
            (can be lexicon index then you should store inside LM the
            mapping between indices of lexicon and lm, or lm index of a word)

        Returns:
        --------
        (LMState, float): pair of (new state, score for the current word)
        """
        curr_state = self.states[state]

        def trim_cache(targ_size):
            while len(self.stateq) > targ_size:
                rem_k = self.stateq.popleft()
                rem_st = self.states[rem_k]
                rem_st = FairseqLMState(rem_st.prefix, None, None)
                self.states[rem_k] = rem_st

        if curr_state.probs is None:
            new_incremental_state = (
                curr_state.incremental_state.copy()
                if curr_state.incremental_state is not None
                else None
            )
            with torch.no_grad():
                if new_incremental_state is not None:
                    new_incremental_state = apply_to_sample(
                        lambda x: x.cuda(), new_incremental_state
                    )
                elif self.save_incremental:
                    new_incremental_state = {}

                res = self.model(
                    torch.from_numpy(curr_state.prefix).cuda(),
                    incremental_state=new_incremental_state,
                )
                probs = self.model.get_normalized_probs(
                    res, log_probs=True, sample=None
                )

                if new_incremental_state is not None:
                    new_incremental_state = apply_to_sample(
                        lambda x: x.cpu(), new_incremental_state
                    )

                curr_state = FairseqLMState(
                    curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
                )

            if not no_cache:
                self.states[state] = curr_state
                self.stateq.append(state)

        score = curr_state.probs[token_index].item()

        trim_cache(self.max_cache)

        outstate = state.child(token_index)
        if outstate not in self.states and not no_cache:
            prefix = np.concatenate(
                [curr_state.prefix, torch.LongTensor([[token_index]])], -1
            )
            incr_state = curr_state.incremental_state

            self.states[outstate] = FairseqLMState(prefix, incr_state, None)

        if token_index == self.unk:
            score = float("-inf")

        return outstate, score

    def finish(self, state: LMState):
        """
        Evaluate eos for language model based on the current lm state

        Returns:
        --------
        (LMState, float): pair of (new state, score for the current word)
        """
        return self.score(state, self.dictionary.eos())

    def empty_cache(self):
        self.states = {}
        self.stateq = deque()
        gc.collect()


class W2lFairseqLMDecoder(W2lDecoder):
    def __init__(self, args, tgt_dict):
        super().__init__(args, tgt_dict)

        self.silence = tgt_dict.bos()

        self.unit_lm = getattr(args, "unit_lm", False)

        self.lexicon = load_words(args.lexicon) if args.lexicon else None
        self.idx_to_wrd = {}

        checkpoint = torch.load(args.kenlm_model, map_location="cpu")
        lm_args = checkpoint["args"]
        lm_args.data = osp.dirname(args.kenlm_model)
        print(lm_args)
        task = tasks.setup_task(lm_args)
        model = task.build_model(lm_args)
        model.load_state_dict(checkpoint["model"], strict=False)

        self.trie = Trie(self.vocab_size, self.silence)

        self.word_dict = task.dictionary
        self.unk_word = self.word_dict.unk()
        self.lm = FairseqLM(self.word_dict, model)

        self.decoder_opts = DecoderOptions(
            args.beam,
            int(getattr(args, "beam_size_token", len(tgt_dict))),
            args.beam_threshold,
            args.lm_weight,
            args.word_score,
            args.unk_weight,
            args.sil_weight,
            0,
            False,
            self.criterion_type,
        )

        if self.lexicon:
            start_state = self.lm.start(False)
            for i, (word, spellings) in enumerate(self.lexicon.items()):
                if self.unit_lm:
                    word_idx = i
                    self.idx_to_wrd[i] = word
                    score = 0
                else:
                    word_idx = self.word_dict.index(word)
                    _, score = self.lm.score(start_state, word_idx, no_cache=True)

                for spelling in spellings:
                    spelling_idxs = [tgt_dict.index(token) for token in spelling]
                    assert (
                        tgt_dict.unk() not in spelling_idxs
                    ), f"{spelling} {spelling_idxs}"
                    self.trie.insert(spelling_idxs, word_idx, score)
            self.trie.smear(SmearingMode.MAX)

            self.decoder = LexiconDecoder(
                self.decoder_opts,
                self.trie,
                self.lm,
                self.silence,
                self.blank,
                self.unk_word,
                [],
                self.unit_lm,
            )
        else:
            self.decoder = LexiconFreeDecoder(
                self.decoder_opts, self.lm, self.silence, self.blank, []
            )

    def decode(self, emissions):
        B, T, N = emissions.size()
        hypos = []

        def idx_to_word(idx):
            if self.unit_lm:
                return self.idx_to_wrd[idx]
            else:
                return self.word_dict[idx]

        def make_hypo(result):
            hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
            if self.lexicon:
                hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
            return hypo

        for b in range(B):
            emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
            results = self.decoder.decode(emissions_ptr, T, N)

            nbest_results = results[: self.nbest]
            hypos.append([make_hypo(result) for result in nbest_results])
        self.lm.empty_cache()

        return hypos
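`W2lDecoder.get_tokens` turns a frame-level label sequence into a token sequence by first collapsing runs of repeated indices with `itertools.groupby` and then dropping the CTC blank. A self-contained sketch of that post-processing (the blank index of 0 here is illustrative; fairseq derives it from the target dictionary):

```python
import itertools as it

# Sketch of the CTC post-processing in W2lDecoder.get_tokens: merge runs
# of repeated frame-level indices, then remove the blank index.
def ctc_collapse(frame_idxs, blank=0):
    deduped = (g[0] for g in it.groupby(frame_idxs))  # merge repeats
    return [i for i in deduped if i != blank]         # drop blanks

# Repeats are merged before blanks are removed, so a doubled token
# survives only when a blank separates the two runs (11, 0, 11 below).
print(ctc_collapse([7, 7, 0, 4, 4, 11, 0, 11, 14, 0]))  # [7, 4, 11, 11, 14]
```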
fairseq-0.10.2/examples/speech_to_text/data_utils.py
ADDED
@@ -0,0 +1,262 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import os
import os.path as op
import zipfile
from functools import reduce
from glob import glob
from multiprocessing import cpu_count
from typing import Any, Dict, List

import numpy as np
import sentencepiece as sp
from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
from tqdm import tqdm


UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1


def gen_vocab(
    input_path: str,
    output_path_prefix: str,
    model_type="bpe",
    vocab_size=1000,
):
    # Train SentencePiece Model
    arguments = [
        f"--input={input_path}",
        f"--model_prefix={output_path_prefix}",
        f"--model_type={model_type}",
        f"--vocab_size={vocab_size}",
        "--character_coverage=1.0",
        f"--num_threads={cpu_count()}",
        f"--unk_id={UNK_TOKEN_ID}",
        f"--bos_id={BOS_TOKEN_ID}",
        f"--eos_id={EOS_TOKEN_ID}",
        f"--pad_id={PAD_TOKEN_ID}",
    ]
    sp.SentencePieceTrainer.Train(" ".join(arguments))
    # Export fairseq dictionary
    spm = sp.SentencePieceProcessor()
    spm.Load(output_path_prefix + ".model")
    vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
    assert (
        vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
        and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
        and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
        and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
    )
    vocab = {
        i: s
        for i, s in vocab.items()
        if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
    }
    with open(output_path_prefix + ".txt", "w") as f_out:
        for _, s in sorted(vocab.items(), key=lambda x: x[0]):
            f_out.write(f"{s} 1\n")


def extract_fbank_features(
    waveform,
    sample_rate,
    output_path=None,
    n_mel_bins=80,
    apply_utterance_cmvn=True,
    overwrite=False,
):
    if output_path is not None and op.exists(output_path) and not overwrite:
        return

    _waveform = waveform * (2 ** 15)  # Kaldi compliance: 16-bit signed integers
    _waveform = _waveform.squeeze().numpy()

    features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
    if features is None:
        features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
    if features is None:
        raise ImportError(
            "Please install pyKaldi or torchaudio to enable "
            "online filterbank feature extraction"
        )

    if apply_utterance_cmvn:
        cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
        features = cmvn(features)
    if output_path is not None:
        np.save(output_path, features)
    else:
        return features


def create_zip(data_root, zip_path):
    cwd = os.path.abspath(os.curdir)
    os.chdir(data_root)
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
        for filename in tqdm(glob("*.npy")):
            f.write(filename)
    os.chdir(cwd)


def is_npy_data(data: bytes) -> bool:
    return data[0] == 147 and data[1] == 78

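`is_npy_data` above recognizes a serialized numpy array by its magic number: byte `0x93` (147) followed by the ASCII string `NUMPY` (so the second byte is `ord("N") == 78`). A quick standalone check of that invariant against numpy's own serializer:

```python
import io
import numpy as np

# Same magic-number test as in the manifest code above.
def is_npy_data(data: bytes) -> bool:
    return data[0] == 147 and data[1] == 78

# Serialize a real array to memory and confirm the header matches.
buf = io.BytesIO()
np.save(buf, np.zeros(3))
print(is_npy_data(buf.getvalue()))  # True
```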
def get_zip_manifest(zip_root, zip_filename):
    zip_path = op.join(zip_root, zip_filename)
    with zipfile.ZipFile(zip_path, mode="r") as f:
        info = f.infolist()
    manifest = {}
    for i in tqdm(info):
        utt_id = op.splitext(i.filename)[0]
        offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
        manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
        with open(zip_path, "rb") as f:
            f.seek(offset)
            data = f.read(file_size)
            assert len(data) > 1 and is_npy_data(data)
    return manifest


def gen_config_yaml(
    data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
):
    assert specaugment_policy in {"lb", "ld"}
    data_root = op.abspath(data_root)
    writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
    writer.set_audio_root(op.abspath(data_root))
    writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
    writer.set_input_channels(1)
    writer.set_input_feat_per_channel(80)
    if specaugment_policy == "lb":
        writer.set_specaugment_lb_policy()
    else:
        writer.set_specaugment_ld_policy()
    writer.set_bpe_tokenizer(
        {
            "bpe": "sentencepiece",
            "sentencepiece_model": op.join(data_root, spm_filename),
        }
    )
    writer.set_feature_transforms("_train", ["specaugment"])
    writer.flush()


def save_df_to_tsv(dataframe, path):
    dataframe.to_csv(
        path,
        sep="\t",
        header=True,
        index=False,
        encoding="utf-8",
        escapechar="\\",
        quoting=csv.QUOTE_NONE,
    )


def filter_manifest_df(
    df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
):
    filters = {
        "no speech": df["audio"] == "",
        f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
        "empty sentence": df["tgt_text"] == "",
    }
    if is_train_split:
        filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
    if extra_filters is not None:
        filters.update(extra_filters)
    invalid = reduce(lambda x, y: x | y, filters.values())
    valid = ~invalid
    print(
        "| "
        + ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
        + f", total {invalid.sum()} filtered, {valid.sum()} remained."
    )
    return df[valid]


class S2TDataConfigWriter(object):
    DEFAULT_VOCAB_FILENAME = "dict.txt"
    DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
    DEFAULT_INPUT_CHANNELS = 1

    def __init__(self, yaml_path):
        try:
            import yaml
        except ImportError:
            print("Please install PyYAML to load YAML files for S2T data config")
        self.yaml = yaml
        self.yaml_path = yaml_path
        self.config = {}

    def flush(self):
        with open(self.yaml_path, "w") as f:
            self.yaml.dump(self.config, f)

    def set_audio_root(self, audio_root=""):
        self.config["audio_root"] = audio_root

    def set_vocab_filename(self, vocab_filename="dict.txt"):
        self.config["vocab_filename"] = vocab_filename

    def set_specaugment(
        self,
        time_wrap_w: int,
        freq_mask_n: int,
        freq_mask_f: int,
        time_mask_n: int,
        time_mask_t: int,
        time_mask_p: float,
    ):
        self.config["specaugment"] = {
            "time_wrap_W": time_wrap_w,
            "freq_mask_N": freq_mask_n,
            "freq_mask_F": freq_mask_f,
            "time_mask_N": time_mask_n,
            "time_mask_T": time_mask_t,
            "time_mask_p": time_mask_p,
        }

    def set_specaugment_lb_policy(self):
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=1,
            freq_mask_f=27,
            time_mask_n=1,
            time_mask_t=100,
            time_mask_p=1.0,
        )

    def set_specaugment_ld_policy(self):
        self.set_specaugment(
            time_wrap_w=0,
            freq_mask_n=2,
            freq_mask_f=27,
            time_mask_n=2,
            time_mask_t=100,
|
| 247 |
+
time_mask_p=1.0,
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
def set_input_channels(self, input_channels=1):
|
| 251 |
+
self.config["input_channels"] = input_channels
|
| 252 |
+
|
| 253 |
+
def set_input_feat_per_channel(self, input_feat_per_channel=80):
|
| 254 |
+
self.config["input_feat_per_channel"] = input_feat_per_channel
|
| 255 |
+
|
| 256 |
+
def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
|
| 257 |
+
self.config["bpe_tokenizer"] = bpe_tokenizer
|
| 258 |
+
|
| 259 |
+
def set_feature_transforms(self, split, transforms: List[str]):
|
| 260 |
+
if "transforms" not in self.config:
|
| 261 |
+
self.config["transforms"] = {}
|
| 262 |
+
self.config["transforms"][split] = transforms
|
fairseq-0.10.2/examples/speech_to_text/prep_covost_data.py
ADDED
@@ -0,0 +1,294 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import csv
import logging
import os
import os.path as op
import shutil
from tempfile import NamedTemporaryFile
from typing import Optional, Tuple

import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    filter_manifest_df,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from torchaudio.datasets.utils import download_url, extract_archive
from tqdm import tqdm


log = logging.getLogger(__name__)


MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]


class CoVoST(Dataset):
    """Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).

    Args:
        root (str): root path to the dataset and generated manifests/features
        source_language (str): source (audio) language
        target_language (str, optional): target (text) language,
            None for no translation (default: None)
        version (int, optional): CoVoST version. (default: 2)
        download (bool, optional): Whether to download the dataset if it is not
            found at root path. (default: ``False``).
    """

    CV_URL_TEMPLATE = (
        "https://voice-prod-bundler-ee1969a6ce8178826482b88"
        "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
    )
    COVOST_URL_TEMPLATE = (
        "https://dl.fbaipublicfiles.com/covost/"
        "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
    )

    VERSIONS = {2}
    SPLITS = ["train", "dev", "test"]

    CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}

    XX_EN_LANGUAGES = {
        1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
        2: [
            "fr",
            "de",
            "es",
            "ca",
            "it",
            "ru",
            "zh-CN",
            "pt",
            "fa",
            "et",
            "mn",
            "nl",
            "tr",
            "ar",
            "sv-SE",
            "lv",
            "sl",
            "ta",
            "ja",
            "id",
            "cy",
        ],
    }
    EN_XX_LANGUAGES = {
        1: [],
        2: [
            "de",
            "tr",
            "fa",
            "sv-SE",
            "mn",
            "zh-CN",
            "cy",
            "ca",
            "sl",
            "et",
            "id",
            "ar",
            "ta",
            "lv",
            "ja",
        ],
    }

    def __init__(
        self,
        root: str,
        split: str,
        source_language: str,
        target_language: Optional[str] = None,
        version: int = 2,
        download: bool = False,
    ) -> None:
        assert version in self.VERSIONS and split in self.SPLITS
        assert source_language is not None
        self.no_translation = target_language is None
        if not self.no_translation:
            assert "en" in {source_language, target_language}
            if source_language == "en":
                assert target_language in self.EN_XX_LANGUAGES[version]
            else:
                assert source_language in self.XX_EN_LANGUAGES[version]
        else:
            # Hack here so that we can get "split" column from CoVoST TSV.
            # Note that we use CoVoST train split for ASR which is an extension
            # to Common Voice train split.
            target_language = "de" if source_language == "en" else "en"

        self.root = os.path.join(root, "raw")
        os.makedirs(self.root, exist_ok=True)

        cv_url = self.CV_URL_TEMPLATE.format(
            ver=self.CV_VERSION_ID[version], lang=source_language
        )
        cv_archive = os.path.join(self.root, os.path.basename(cv_url))
        if download:
            if not os.path.isfile(cv_archive):
                download_url(cv_url, self.root, hash_value=None)
            extract_archive(cv_archive)

        covost_url = self.COVOST_URL_TEMPLATE.format(
            src_lang=source_language, tgt_lang=target_language
        )
        covost_archive = os.path.join(self.root, os.path.basename(covost_url))
        if download:
            if not os.path.isfile(covost_archive):
                download_url(covost_url, self.root, hash_value=None)
            extract_archive(covost_archive)

        cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
        covost_tsv = self.load_from_tsv(
            os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
        )
        df = pd.merge(
            left=cv_tsv[["path", "sentence", "client_id"]],
            right=covost_tsv[["path", "translation", "split"]],
            how="inner",
            on="path",
        )
        if split == "train":
            df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
        else:
            df = df[df["split"] == split]
        self.data = df.to_dict(orient="index").items()
        self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]

    @classmethod
    def load_from_tsv(cls, path: str):
        return pd.read_csv(
            path,
            sep="\t",
            header=0,
            encoding="utf-8",
            escapechar="\\",
            quoting=csv.QUOTE_NONE,
            na_filter=False,
        )

    def __getitem__(
        self, n: int
    ) -> Tuple[Tensor, int, str, str, Optional[str], str, str]:
        """Load the n-th sample from the dataset.

        Args:
            n (int): The index of the sample to be loaded

        Returns:
            tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
            sample_id)``
        """
        data = self.data[n]
        path = os.path.join(self.root, "clips", data["path"])
        waveform, sample_rate = torchaudio.load(path)
        sentence = data["sentence"]
        translation = None if self.no_translation else data["translation"]
        speaker_id = data["client_id"]
        _id = data["path"].replace(".mp3", "")
        return waveform, sample_rate, sentence, translation, speaker_id, _id

    def __len__(self) -> int:
        return len(self.data)


def process(args):
    root = op.join(args.data_root, args.src_lang)
    os.makedirs(root, exist_ok=True)
    # Extract features
    feature_root = op.join(root, "fbank80")
    os.makedirs(feature_root, exist_ok=True)
    for split in CoVoST.SPLITS:
        print(f"Fetching split {split}...")
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
        print("Extracting log mel filter bank features...")
        for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
            extract_fbank_features(
                waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
            )
    # Pack features into ZIP
    zip_filename = "fbank80.zip"
    zip_path = op.join(root, zip_filename)
    print("ZIPing features...")
    create_zip(feature_root, zip_path)
    print("Fetching ZIP manifest...")
    zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
    # Generate TSV manifest
    print("Generating manifest...")
    train_text = []
    task = f"asr_{args.src_lang}"
    if args.tgt_lang is not None:
        task = f"st_{args.src_lang}_{args.tgt_lang}"
    for split in CoVoST.SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
        for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
            manifest["id"].append(utt_id)
            manifest["audio"].append(zip_manifest[utt_id])
            duration_ms = int(wav.size(1) / sr * 1000)
            manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
            manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
            manifest["speaker"].append(speaker_id)
        is_train_split = split.startswith("train")
        if is_train_split:
            train_text.extend(manifest["tgt_text"])
        df = pd.DataFrame.from_dict(manifest)
        df = filter_manifest_df(df, is_train_split=is_train_split)
        save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
    # Generate vocab
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
    with NamedTemporaryFile(mode="w") as f:
        for t in train_text:
            f.write(t + "\n")
        gen_vocab(
            f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
        )
    # Generate config YAML
    gen_config_yaml(
        root,
        spm_filename_prefix + ".model",
        yaml_filename=f"config_{task}.yaml",
        specaugment_policy="lb",
    )
    # Clean up
    shutil.rmtree(feature_root)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=1000, type=int)
    parser.add_argument("--src-lang", "-s", required=True, type=str)
    parser.add_argument("--tgt-lang", "-t", type=str)
    args = parser.parse_args()

    process(args)


if __name__ == "__main__":
    main()
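
The `n_frames` column written by these scripts comes from the frame-count arithmetic `int(1 + (duration_ms - 25) / 10)`, i.e. 25 ms analysis windows with a 10 ms hop. A small self-contained sketch of that computation (the helper name and the sample values are illustrative, not part of fairseq):

```python
def n_frames(num_samples: int, sample_rate: int) -> int:
    # Same arithmetic as the manifest-generation loops above:
    # clip duration in ms, then 25 ms windows advanced by 10 ms.
    duration_ms = int(num_samples / sample_rate * 1000)
    return int(1 + (duration_ms - 25) / 10)

# A hypothetical 2-second clip at 16 kHz: 2000 ms -> int(1 + 197.5) = 198 frames.
print(n_frames(32000, 16000))  # → 198
```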
fairseq-0.10.2/examples/speech_to_text/prep_librispeech_data.py
ADDED
@@ -0,0 +1,119 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
import os.path as op
import shutil
from tempfile import NamedTemporaryFile

import pandas as pd
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    save_df_to_tsv,
)
from torchaudio.datasets import LIBRISPEECH
from tqdm import tqdm


log = logging.getLogger(__name__)

SPLITS = [
    "train-clean-100",
    "train-clean-360",
    "train-other-500",
    "dev-clean",
    "dev-other",
    "test-clean",
    "test-other",
]

MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]


def process(args):
    os.makedirs(args.output_root, exist_ok=True)
    # Extract features
    feature_root = op.join(args.output_root, "fbank80")
    os.makedirs(feature_root, exist_ok=True)
    for split in SPLITS:
        print(f"Fetching split {split}...")
        dataset = LIBRISPEECH(args.output_root, url=split, download=True)
        print("Extracting log mel filter bank features...")
        for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
            sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
            extract_fbank_features(
                wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
            )
    # Pack features into ZIP
    zip_filename = "fbank80.zip"
    zip_path = op.join(args.output_root, zip_filename)
    print("ZIPing features...")
    create_zip(feature_root, zip_path)
    print("Fetching ZIP manifest...")
    zip_manifest = get_zip_manifest(args.output_root, zip_filename)
    # Generate TSV manifest
    print("Generating manifest...")
    train_text = []
    for split in SPLITS:
        manifest = {c: [] for c in MANIFEST_COLUMNS}
        dataset = LIBRISPEECH(args.output_root, url=split)
        for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
            sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
            manifest["id"].append(sample_id)
            manifest["audio"].append(zip_manifest[sample_id])
            duration_ms = int(wav.size(1) / sample_rate * 1000)
            manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
            manifest["tgt_text"].append(utt)
            manifest["speaker"].append(spk_id)
        save_df_to_tsv(
            pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
        )
        if split.startswith("train"):
            train_text.extend(manifest["tgt_text"])
    # Generate vocab
    vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
    with NamedTemporaryFile(mode="w") as f:
        for t in train_text:
            f.write(t + "\n")
        gen_vocab(
            f.name,
            op.join(args.output_root, spm_filename_prefix),
            args.vocab_type,
            args.vocab_size,
        )
    # Generate config YAML
    gen_config_yaml(
        args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
    )
    # Clean up
    shutil.rmtree(feature_root)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-root", "-o", required=True, type=str)
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=10000, type=int)
    args = parser.parse_args()

    process(args)


if __name__ == "__main__":
    main()
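
`save_df_to_tsv` writes the manifest with `quoting=csv.QUOTE_NONE` and a backslash escape character, so embedded tabs are escaped rather than quoted. A minimal sketch of that behavior writing to an in-memory buffer (the manifest values are hypothetical, and the `encoding` argument is omitted since we are not writing to a file path):

```python
import csv
import io

import pandas as pd

# Toy manifest row in the MANIFEST_COLUMNS layout used by these scripts.
manifest = {
    "id": ["spk1-0-0"],
    "audio": ["fbank80.zip:1024:4096"],
    "n_frames": [98],
    "tgt_text": ["HELLO WORLD"],
    "speaker": ["spk1"],
}
buf = io.StringIO()
# Same formatting options as save_df_to_tsv above.
pd.DataFrame.from_dict(manifest).to_csv(
    buf,
    sep="\t",
    header=True,
    index=False,
    escapechar="\\",
    quoting=csv.QUOTE_NONE,
)
lines = buf.getvalue().splitlines()
print(lines[0])  # tab-separated header row, no quoting
```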
fairseq-0.10.2/examples/speech_to_text/prep_mustc_data.py
ADDED
@@ -0,0 +1,200 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
import os
import os.path as op
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple

import pandas as pd
import torchaudio
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    filter_manifest_df,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    save_df_to_tsv,
)
from torch import Tensor
from torch.utils.data import Dataset
from tqdm import tqdm


log = logging.getLogger(__name__)


MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
TASKS = ["asr", "st"]


class MUSTC(Dataset):
    """
    Create a Dataset for MuST-C. Each item is a tuple of the form:
    waveform, sample_rate, source utterance, target utterance, speaker_id,
    utterance_id
    """

    SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
    LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]

    def __init__(self, root: str, lang: str, split: str) -> None:
        assert split in self.SPLITS and lang in self.LANGUAGES
        _root = op.join(root, f"en-{lang}", "data", split)
        wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
        assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
        # Load audio segments
        try:
            import yaml
        except ImportError:
            print("Please install PyYAML to load YAML files for the MuST-C dataset")
        with open(op.join(txt_root, f"{split}.yaml")) as f:
            segments = yaml.load(f, Loader=yaml.BaseLoader)
        # Load source and target utterances
        for _lang in ["en", lang]:
            with open(op.join(txt_root, f"{split}.{_lang}")) as f:
                utterances = [r.strip() for r in f]
            assert len(segments) == len(utterances)
            for i, u in enumerate(utterances):
                segments[i][_lang] = u
        # Gather info
        self.data = []
        for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
            wav_path = op.join(wav_root, wav_filename)
            sample_rate = torchaudio.info(wav_path)[0].rate
            seg_group = sorted(_seg_group, key=lambda x: x["offset"])
            for i, segment in enumerate(seg_group):
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
                _id = f"{op.splitext(wav_filename)[0]}_{i}"
                self.data.append(
                    (
                        wav_path,
                        offset,
                        n_frames,
                        sample_rate,
                        segment["en"],
                        segment[lang],
                        segment["speaker_id"],
                        _id,
                    )
                )

    def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
        waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
        return waveform, sr, src_utt, tgt_utt, spk_id, utt_id

    def __len__(self) -> int:
        return len(self.data)


def process(args):
    for lang in MUSTC.LANGUAGES:
        cur_root = op.join(args.data_root, f"en-{lang}")
        if not op.isdir(cur_root):
            print(f"{cur_root} does not exist. Skipped.")
            continue
        # Extract features
        feature_root = op.join(cur_root, "fbank80")
        os.makedirs(feature_root, exist_ok=True)
        for split in MUSTC.SPLITS:
            print(f"Fetching split {split}...")
            dataset = MUSTC(args.data_root, lang, split)
            print("Extracting log mel filter bank features...")
            for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
                extract_fbank_features(
                    waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
                )
        # Pack features into ZIP
        zip_filename = "fbank80.zip"
        zip_path = op.join(cur_root, zip_filename)
        print("ZIPing features...")
        create_zip(feature_root, zip_path)
        print("Fetching ZIP manifest...")
        zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
        # Generate TSV manifest
        print("Generating manifest...")
        train_text = {task: [] for task in TASKS}
        for split in MUSTC.SPLITS:
            is_train_split = split.startswith("train")
            manifest = {c: [] for c in MANIFEST_COLUMNS}
            text = {task: [] for task in TASKS}
            dataset = MUSTC(args.data_root, lang, split)
            for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
                manifest["id"].append(utt_id)
                manifest["audio"].append(zip_manifest[utt_id])
                duration_ms = int(wav.size(1) / sr * 1000)
                manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
                text["asr"].append(src_utt)
                text["st"].append(tgt_utt)
                manifest["speaker"].append(speaker_id)
            if is_train_split:
                for task in TASKS:
                    train_text[task].extend(text[task])
            for task in TASKS:
                manifest["tgt_text"] = text[task]
                df = pd.DataFrame.from_dict(manifest)
                df = filter_manifest_df(df, is_train_split=is_train_split)
                save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
        # Generate vocab
        for task in TASKS:
            vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
            if task == "st":
                vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
            vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
            spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
            with NamedTemporaryFile(mode="w") as f:
                for t in train_text[task]:
                    f.write(t + "\n")
                gen_vocab(
                    f.name,
                    op.join(cur_root, spm_filename_prefix),
                    vocab_type,
                    vocab_size,
                )
            # Generate config YAML
            gen_config_yaml(
                cur_root,
                spm_filename_prefix + ".model",
                yaml_filename=f"config_{task}.yaml",
                specaugment_policy="lb",
            )
        # Clean up
        shutil.rmtree(feature_root)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument(
        "--asr-vocab-type",
        default="unigram",
        required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument(
        "--st-vocab-type",
        default="unigram",
        required=True,
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--asr-vocab-size", default=5000, type=int)
    parser.add_argument("--st-vocab-size", default=8000, type=int)
    args = parser.parse_args()

    process(args)


if __name__ == "__main__":
    main()
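
The MuST-C loader groups YAML segments by their source wav with `itertools.groupby`, which only merges *adjacent* equal keys; the dataset's segment list keeps entries for the same wav contiguous, so this works without a global sort. A self-contained sketch of the grouping and utterance-id scheme on toy data (the segment values are made up):

```python
from itertools import groupby

# Toy stand-in for the MuST-C segment list: same-wav entries are contiguous.
segments = [
    {"wav": "ted_1.wav", "offset": "0.0", "duration": "2.0"},
    {"wav": "ted_1.wav", "offset": "2.5", "duration": "1.5"},
    {"wav": "ted_2.wav", "offset": "0.0", "duration": "3.0"},
]

ids = []
for wav_filename, seg_group in groupby(segments, lambda x: x["wav"]):
    # Within a wav, segments are sorted by offset and numbered, giving
    # utterance ids like "<wav basename>_<index>" as in MUSTC.__init__.
    for i, seg in enumerate(sorted(seg_group, key=lambda x: x["offset"])):
        ids.append(f"{wav_filename.rsplit('.', 1)[0]}_{i}")
print(ids)  # → ['ted_1_0', 'ted_1_1', 'ted_2_0']
```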
fairseq-0.10.2/examples/unsupervised_quality_estimation/README.md
ADDED
@@ -0,0 +1,126 @@
# Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)

This page includes instructions for reproducing results from the paper [Unsupervised Quality Estimation for Neural
Machine Translation (Fomicheva et al., 2020)](https://arxiv.org/abs/2005.10608).

## Requirements:

* mosesdecoder: https://github.com/moses-smt/mosesdecoder
* subword-nmt: https://github.com/rsennrich/subword-nmt
* flores: https://github.com/facebookresearch/flores

## Download Models and Test Data

Download translation models and test data from the [MLQE dataset repository](https://github.com/facebookresearch/mlqe).

## Set up:

Given a test set consisting of source sentences and reference translations, define:

* `SRC_LANG`: source language
* `TGT_LANG`: target language
* `INPUT`: input prefix, such that the file `$INPUT.$SRC_LANG` contains the source sentences and `$INPUT.$TGT_LANG` contains the reference sentences
* `OUTPUT_DIR`: output path to store results
* `MOSES_DECODER`: path to the mosesdecoder installation
* `BPE_ROOT`: path to the subword-nmt installation
* `BPE`: path to the BPE model
* `MODEL_DIR`: directory containing the NMT model `.pt` file as well as the source and target vocabularies
* `TMP`: directory for intermediate temporary files
* `GPU`: if translating with a GPU, the id of the GPU to use for inference
* `DROPOUT_N`: number of stochastic forward passes

`$DROPOUT_N` is set to 30 in the experiments reported in the paper. However, we observed that increasing it beyond 10
does not bring substantial improvements.

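For a concrete starting point, the variables above might be set as follows. Every path here is an illustrative placeholder, not part of the MLQE release:

```shell
# Illustrative values only; adjust every path to your own checkouts and downloads.
export SRC_LANG=en TGT_LANG=de
export INPUT=data/test                    # expects data/test.en and data/test.de
export OUTPUT_DIR=results TMP=tmp
export MOSES_DECODER=$HOME/mosesdecoder
export BPE_ROOT=$HOME/subword-nmt/subword_nmt
export BPE=models/en-de/bpecodes
export MODEL_DIR=models/en-de
export GPU=0 DROPOUT_N=30
mkdir -p $OUTPUT_DIR $TMP                 # the steps below assume both directories exist
```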
## Translate the data using standard decoding

Preprocess the input data:

```
for LANG in $SRC_LANG $TGT_LANG; do
  perl $MOSES_DECODER/scripts/tokenizer/tokenizer.perl -threads 80 -a -l $LANG < $INPUT.$LANG > $TMP/preprocessed.tok.$LANG
  python $BPE_ROOT/apply_bpe.py -c ${BPE} < $TMP/preprocessed.tok.$LANG > $TMP/preprocessed.tok.bpe.$LANG
done
```

Binarize the data for faster translation:

```
fairseq-preprocess --srcdict $MODEL_DIR/dict.$SRC_LANG.txt --tgtdict $MODEL_DIR/dict.$TGT_LANG.txt \
  --source-lang ${SRC_LANG} --target-lang ${TGT_LANG} --testpref $TMP/preprocessed.tok.bpe --destdir $TMP/bin --workers 4
```

Translate:

```
CUDA_VISIBLE_DEVICES=$GPU fairseq-generate $TMP/bin --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 \
  --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 > $TMP/fairseq.out
grep ^H $TMP/fairseq.out | cut -f3- > $TMP/mt.out
```
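In `fairseq-generate` output, each hypothesis is written on an `H-` prefixed line with tab-separated sentence id, score, and text, so the `grep`/`cut` pair above keeps only the hypothesis text. A small illustration (the line content is made up):

```shell
# A fabricated hypothesis line in fairseq-generate's "H-<id> <score> <text>" tab-separated format.
printf 'H-0\t-0.4217\tDas ist ein Test .\n' > sample.out
# Keep only the hypothesis text (field 3 onward):
grep ^H sample.out | cut -f3-   # prints: Das ist ein Test .
```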

Post-process:

```
sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/mt.out | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl \
  -l $TGT_LANG > $OUTPUT_DIR/mt.out
```
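The `sed` expression removes the `@@ ` continuation markers that subword-nmt inserts, rejoining BPE pieces into whole words before detokenization. For example:

```shell
# "trans@@ lation" is one word split into two BPE pieces; the sed rejoins it.
echo 'The trans@@ lation is unsupervis@@ ed .' | sed -r 's/(@@ )| (@@ ?$)//g'
# prints: The translation is unsupervised .
```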

## Produce uncertainty estimates

### Scoring

Make temporary files to store the translations repeated N times:

```
python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/preprocessed.tok.bpe.$SRC_LANG -n $DROPOUT_N \
  -o $TMP/repeated.$SRC_LANG
python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/mt.out -n $DROPOUT_N -o $TMP/repeated.$TGT_LANG

fairseq-preprocess --srcdict ${MODEL_DIR}/dict.${SRC_LANG}.txt --tgtdict ${MODEL_DIR}/dict.${TGT_LANG}.txt \
  --source-lang ${SRC_LANG} --target-lang ${TGT_LANG} --testpref ${TMP}/repeated --destdir ${TMP}/bin-repeated
```
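`repeat_lines.py` is a helper from this example directory; assuming it writes each input line `N` times in a row (so that every sentence can be scored under `N` independent dropout masks), its effect can be sketched with `awk`:

```shell
# Repeat every input line n times consecutively (a sketch of what repeat_lines.py is assumed to do).
printf 'first sentence\nsecond sentence\n' | awk -v n=3 '{for (i = 0; i < n; i++) print}'
# prints "first sentence" three times, then "second sentence" three times
```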

Produce model scores for the generated translations, using the `--retain-dropout` option to apply dropout at inference time:

```
CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 \
  --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout \
  --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer \
  TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out

grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores
```
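With `--score-reference`, the text on each `H-` line is the given translation itself, and the second field holds the log-probability that this dropout-perturbed pass assigns to it; `cut -f2-` keeps everything from the score field onward. A fabricated example line:

```shell
# Field 2 of an H- line is the sentence-level score under the sampled dropout mask.
printf 'H-0\t-0.5310\tdas ist gut\n' > scored.out
grep ^H scored.out | cut -f2   # prints: -0.5310
```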

Use `--retain-dropout-modules` to specify the modules in which dropout is retained. By default, dropout is applied in the same places
as during training.

Compute the mean of the resulting output distribution:

```
python $SCRIPTS/scripts/uncertainty/aggregate_scores.py -i $TMP/dropout.scores -o $OUTPUT_DIR/dropout.scores.mean \
  -n $DROPOUT_N
```
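`aggregate_scores.py` reduces the `N` per-pass scores of each sentence to their mean. Assuming the scores for one sentence sit on `N` consecutive lines (the layout produced by the repetition step above), the aggregation amounts to:

```shell
# Average every block of n consecutive scores: one averaged value per sentence.
printf '1.0\n3.0\n2.0\n4.0\n' | awk -v n=2 '{s += $1} NR % n == 0 {print s / n; s = 0}'
# prints: 2 and 3 (the means of the two blocks)
```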

### Generation

Produce multiple translation hypotheses for the same source using the `--retain-dropout` option:

```
CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt \
  --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --retain-dropout \
  --unkpen 5 --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder \
  TransformerEncoderLayer TransformerDecoderLayer --seed 46 > $TMP/dropout.generation.out

grep ^H $TMP/dropout.generation.out | cut -f3- > $TMP/dropout.hypotheses_

sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/dropout.hypotheses_ | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl \
  -l $TGT_LANG > $TMP/dropout.hypotheses
```

Compute the similarity between the multiple hypotheses corresponding to the same source sentence using the Meteor
evaluation metric:

```
python meteor.py -i $TMP/dropout.hypotheses -m <path_to_meteor_installation> -n $DROPOUT_N \
  -o $OUTPUT_DIR/dropout.gen.sim.meteor
```