sleepyhead111 committed on
Commit fbd4c5c · verified · 1 Parent(s): b12c168

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full changeset.

Files changed (50)
  1. fairseq-0.10.2/examples/backtranslation/README.md +297 -0
  2. fairseq-0.10.2/examples/backtranslation/deduplicate_lines.py +41 -0
  3. fairseq-0.10.2/examples/backtranslation/extract_bt_data.py +72 -0
  4. fairseq-0.10.2/examples/backtranslation/prepare-de-monolingual.sh +98 -0
  5. fairseq-0.10.2/examples/backtranslation/sacrebleu.sh +37 -0
  6. fairseq-0.10.2/examples/backtranslation/tokenized_bleu.sh +46 -0
  7. fairseq-0.10.2/examples/bart/README.glue.md +99 -0
  8. fairseq-0.10.2/examples/bart/README.md +218 -0
  9. fairseq-0.10.2/examples/bart/README.summarization.md +121 -0
  10. fairseq-0.10.2/examples/criss/download_and_preprocess_flores_test.sh +64 -0
  11. fairseq-0.10.2/examples/criss/download_and_preprocess_tatoeba.sh +37 -0
  12. fairseq-0.10.2/examples/criss/mining/mine.py +233 -0
  13. fairseq-0.10.2/examples/criss/mining/mine_example.sh +103 -0
  14. fairseq-0.10.2/examples/criss/save_encoder.py +213 -0
  15. fairseq-0.10.2/examples/criss/sentence_retrieval/encoder_analysis.py +92 -0
  16. fairseq-0.10.2/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh +59 -0
  17. fairseq-0.10.2/examples/criss/unsupervised_mt/eval.sh +37 -0
  18. fairseq-0.10.2/examples/cross_lingual_language_model/README.md +77 -0
  19. fairseq-0.10.2/examples/latent_depth/latent_depth_src/loss/latent_depth.py +99 -0
  20. fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/__init__.py +0 -0
  21. fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py +59 -0
  22. fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/latent_transformer.py +146 -0
  23. fairseq-0.10.2/examples/latent_depth/latent_depth_src/modules/__init__.py +0 -0
  24. fairseq-0.10.2/examples/latent_depth/latent_depth_src/modules/latent_layers.py +86 -0
  25. fairseq-0.10.2/examples/m2m_100/README.md +209 -0
  26. fairseq-0.10.2/examples/m2m_100/install_dependecies.sh +78 -0
  27. fairseq-0.10.2/examples/m2m_100/process_data/remove_too_much_punc.py +36 -0
  28. fairseq-0.10.2/examples/m2m_100/tok.sh +83 -0
  29. fairseq-0.10.2/examples/mbart/README.md +123 -0
  30. fairseq-0.10.2/examples/roberta/README.glue.md +99 -0
  31. fairseq-0.10.2/examples/roberta/README.md +296 -0
  32. fairseq-0.10.2/examples/roberta/README.pretraining.md +98 -0
  33. fairseq-0.10.2/examples/roberta/README.race.md +68 -0
  34. fairseq-0.10.2/examples/roberta/multiprocessing_bpe_encoder.py +130 -0
  35. fairseq-0.10.2/examples/roberta/preprocess_GLUE_tasks.sh +185 -0
  36. fairseq-0.10.2/examples/roberta/preprocess_RACE.sh +59 -0
  37. fairseq-0.10.2/examples/speech_recognition/README.md +106 -0
  38. fairseq-0.10.2/examples/speech_recognition/__init__.py +1 -0
  39. fairseq-0.10.2/examples/speech_recognition/criterions/__init__.py +17 -0
  40. fairseq-0.10.2/examples/speech_recognition/criterions/cross_entropy_acc.py +130 -0
  41. fairseq-0.10.2/examples/speech_recognition/datasets/asr_prep_json.py +125 -0
  42. fairseq-0.10.2/examples/speech_recognition/datasets/prepare-librispeech.sh +88 -0
  43. fairseq-0.10.2/examples/speech_recognition/infer.py +464 -0
  44. fairseq-0.10.2/examples/speech_recognition/utils/wer_utils.py +381 -0
  45. fairseq-0.10.2/examples/speech_recognition/w2l_decoder.py +435 -0
  46. fairseq-0.10.2/examples/speech_to_text/data_utils.py +262 -0
  47. fairseq-0.10.2/examples/speech_to_text/prep_covost_data.py +294 -0
  48. fairseq-0.10.2/examples/speech_to_text/prep_librispeech_data.py +119 -0
  49. fairseq-0.10.2/examples/speech_to_text/prep_mustc_data.py +200 -0
  50. fairseq-0.10.2/examples/unsupervised_quality_estimation/README.md +126 -0
fairseq-0.10.2/examples/backtranslation/README.md ADDED
@@ -0,0 +1,297 @@
# Understanding Back-Translation at Scale (Edunov et al., 2018)

This page includes pre-trained models from the paper [Understanding Back-Translation at Scale (Edunov et al., 2018)](https://arxiv.org/abs/1808.09381).

## Pre-trained models

Model | Description | Dataset | Download
---|---|---|---
`transformer.wmt18.en-de` | Transformer <br> ([Edunov et al., 2018](https://arxiv.org/abs/1808.09381)) <br> WMT'18 winner | [WMT'18 English-German](http://www.statmt.org/wmt18/translation-task.html) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/wmt18.en-de.ensemble.tar.gz) <br> See NOTE in the archive

## Example usage (torch.hub)

We require a few additional Python dependencies for preprocessing:
```bash
pip install subword_nmt sacremoses
```

Then to generate translations from the full model ensemble:
```python
import torch

# List available models
torch.hub.list('pytorch/fairseq')  # [..., 'transformer.wmt18.en-de', ...]

# Load the WMT'18 En-De ensemble
en2de_ensemble = torch.hub.load(
    'pytorch/fairseq', 'transformer.wmt18.en-de',
    checkpoint_file='wmt18.model1.pt:wmt18.model2.pt:wmt18.model3.pt:wmt18.model4.pt:wmt18.model5.pt',
    tokenizer='moses', bpe='subword_nmt')

# The ensemble contains 5 models
len(en2de_ensemble.models)
# 5

# Translate
en2de_ensemble.translate('Hello world!')
# 'Hallo Welt!'
```

## Training your own model (WMT'18 English-German)

The following instructions can be adapted to reproduce the models from the paper.

#### Step 1. Prepare parallel data and optionally train a baseline (English-German) model

First download and preprocess the data:
```bash
# Download and prepare the data
cd examples/backtranslation/
bash prepare-wmt18en2de.sh
cd ../..

# Binarize the data
TEXT=examples/backtranslation/wmt18_en_de
fairseq-preprocess \
    --joined-dictionary \
    --source-lang en --target-lang de \
    --trainpref $TEXT/train --validpref $TEXT/valid --testpref $TEXT/test \
    --destdir data-bin/wmt18_en_de --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20

# Copy the BPE code into the data-bin directory for future use
cp examples/backtranslation/wmt18_en_de/code data-bin/wmt18_en_de/code
```

(Optional) Train a baseline model (English-German) using just the parallel data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel
fairseq-train --fp16 \
    data-bin/wmt18_en_de \
    --source-lang en --target-lang de \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 30000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
    --inputs $CHECKPOINT_DIR \
    --num-epoch-checkpoints 10 \
    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
```

Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 29.57, 60.9/35.4/22.9/15.5 (BP=1.000, ratio=1.014, syslen=63049, reflen=62152)
# compare to 29.46 in Table 1, which is also tokenized BLEU

# generally it's better to report (detokenized) sacrebleu, though:
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 29.0 60.6/34.7/22.4/14.9 (BP = 1.000 ratio = 1.013 hyp_len = 62099 ref_len = 61287)
```

#### Step 2. Back-translate monolingual German data

Train a reverse model (German-English) to do the back-translation:
```bash
CHECKPOINT_DIR=checkpoints_de_en_parallel
fairseq-train --fp16 \
    data-bin/wmt18_en_de \
    --source-lang de --target-lang en \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.001 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 30000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Let's evaluate the back-translation (BT) model to make sure it is well trained:
```bash
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    de-en \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint_best.pt
# BLEU+case.mixed+lang.de-en+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 34.9 66.9/41.8/28.5/19.9 (BP = 0.983 ratio = 0.984 hyp_len = 63342 ref_len = 64399)
# compare to the best system from WMT'17, which scored 35.1: http://matrix.statmt.org/matrix/systems_list/1868
```

Next prepare the monolingual data:
```bash
# Download and prepare the monolingual data.
# By default the script samples 25M monolingual sentences, which after
# deduplication should be just over 24M sentences. These are split into 25
# shards, each with 1M sentences (except for the last shard).
cd examples/backtranslation/
bash prepare-de-monolingual.sh
cd ../..

# Binarize each shard of the monolingual data
TEXT=examples/backtranslation/wmt18_de_mono
for SHARD in $(seq -f "%02g" 0 24); do \
    fairseq-preprocess \
        --only-source \
        --source-lang de --target-lang en \
        --joined-dictionary \
        --srcdict data-bin/wmt18_en_de/dict.de.txt \
        --testpref $TEXT/bpe.monolingual.dedup.${SHARD} \
        --destdir data-bin/wmt18_de_mono/shard${SHARD} \
        --workers 20; \
    cp data-bin/wmt18_en_de/dict.en.txt data-bin/wmt18_de_mono/shard${SHARD}/; \
done
```
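The loop above addresses the 25 shards by two-digit ids (`00`–`24`) produced by `seq -f "%02g" 0 24`. When scripting around the shard layout from Python, the equivalent ids and directory names can be generated like this (a small convenience sketch, not part of the example scripts):

```python
# Two-digit shard ids matching `seq -f "%02g" 0 24` in the shell loop above
shard_ids = [f"{i:02d}" for i in range(25)]
shard_dirs = [f"data-bin/wmt18_de_mono/shard{s}" for s in shard_ids]

print(shard_ids[0], shard_ids[-1])  # 00 24
print(shard_dirs[3])                # data-bin/wmt18_de_mono/shard03
```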

Now we're ready to perform back-translation over the monolingual data. The
following command generates via sampling, but it's possible to use greedy
decoding (`--beam 1`), beam search (`--beam 5`),
top-k sampling (`--sampling --beam 1 --sampling-topk 10`), etc.:
```bash
mkdir backtranslation_output
for SHARD in $(seq -f "%02g" 0 24); do \
    fairseq-generate --fp16 \
        data-bin/wmt18_de_mono/shard${SHARD} \
        --path $CHECKPOINT_DIR/checkpoint_best.pt \
        --skip-invalid-size-inputs-valid-test \
        --max-tokens 4096 \
        --sampling --beam 1 \
    > backtranslation_output/sampling.shard${SHARD}.out; \
done
```

After BT, use the `extract_bt_data.py` script to re-combine the shards, extract
the back-translations, and apply length-ratio filters:
```bash
python examples/backtranslation/extract_bt_data.py \
    --minlen 1 --maxlen 250 --ratio 1.5 \
    --output backtranslation_output/bt_data --srclang en --tgtlang de \
    backtranslation_output/sampling.shard*.out

# Ensure the lengths are the same:
# wc -l backtranslation_output/bt_data.{en,de}
#   21795614 backtranslation_output/bt_data.en
#   21795614 backtranslation_output/bt_data.de
#   43591228 total
```
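The `--minlen`/`--maxlen`/`--ratio` filters above drop pairs that are too short, too long, or whose lengths differ too much. A simplified, standalone restatement of that filtering logic (not the script itself):

```python
def keep_pair(src, tgt, minlen=1, maxlen=250, ratio=1.5):
    """Return True if a (source, target) pair passes the length and
    length-ratio filters used when extracting BT data."""
    srclen = len(src.split())
    tgtlen = len(tgt.split())
    if srclen < minlen or tgtlen < minlen:
        return False
    if srclen > maxlen or tgtlen > maxlen:
        return False
    # Discard pairs whose lengths differ by more than the allowed ratio
    if max(srclen, tgtlen) / min(srclen, tgtlen) > ratio:
        return False
    return True

print(keep_pair("ein kurzer Satz .", "a short sentence ."))  # True  (4 vs 4 tokens)
print(keep_pair("ein Wort", " ".join(["tok"] * 40)))         # False (ratio 20 > 1.5)
```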

Binarize the filtered BT data and combine it with the parallel data:
```bash
TEXT=backtranslation_output
fairseq-preprocess \
    --source-lang en --target-lang de \
    --joined-dictionary \
    --srcdict data-bin/wmt18_en_de/dict.en.txt \
    --trainpref $TEXT/bt_data \
    --destdir data-bin/wmt18_en_de_bt \
    --workers 20

# We want to train on the combined data, so we'll symlink the parallel + BT data
# in the wmt18_en_de_para_plus_bt directory. We link the parallel data as "train"
# and the BT data as "train1", so that fairseq will combine them automatically
# and so that we can use the `--upsample-primary` option to upsample the
# parallel data (if desired).
PARA_DATA=$(readlink -f data-bin/wmt18_en_de)
BT_DATA=$(readlink -f data-bin/wmt18_en_de_bt)
COMB_DATA=data-bin/wmt18_en_de_para_plus_bt
mkdir -p $COMB_DATA
for LANG in en de; do \
    ln -s ${PARA_DATA}/dict.$LANG.txt ${COMB_DATA}/dict.$LANG.txt; \
    for EXT in bin idx; do \
        ln -s ${PARA_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train.en-de.$LANG.$EXT; \
        ln -s ${BT_DATA}/train.en-de.$LANG.$EXT ${COMB_DATA}/train1.en-de.$LANG.$EXT; \
        ln -s ${PARA_DATA}/valid.en-de.$LANG.$EXT ${COMB_DATA}/valid.en-de.$LANG.$EXT; \
        ln -s ${PARA_DATA}/test.en-de.$LANG.$EXT ${COMB_DATA}/test.en-de.$LANG.$EXT; \
    done; \
done
```
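With the parallel data linked as "train" and the BT data as "train1", `--upsample-primary 16` repeats each parallel example 16 times relative to the BT data. A back-of-the-envelope on the resulting mix; the parallel-corpus count below is a placeholder, while the BT count is the filtered pair count reported earlier:

```python
parallel_sents = 5_000_000    # placeholder, not taken from this README
bt_sents = 21_795_614         # filtered BT pairs (see the `wc -l` output above)
upsample = 16                 # --upsample-primary 16

# Upsampling repeats each parallel example 16x relative to the BT data:
effective_parallel = parallel_sents * upsample
parallel_share = effective_parallel / (effective_parallel + bt_sents)
print(f"{parallel_share:.1%}")  # 78.6% of sampled training examples are parallel
```

Increasing or decreasing `--upsample-primary` shifts this balance between (clean) parallel and (noisy) synthetic data.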

#### Step 3. Train an English-German model over the combined parallel + BT data

Finally we can train a model over the parallel + BT data:
```bash
CHECKPOINT_DIR=checkpoints_en_de_parallel_plus_bt
fairseq-train --fp16 \
    data-bin/wmt18_en_de_para_plus_bt \
    --upsample-primary 16 \
    --source-lang en --target-lang de \
    --arch transformer_wmt_en_de_big --share-all-embeddings \
    --dropout 0.3 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
    --lr 0.0007 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
    --max-tokens 3584 --update-freq 16 \
    --max-update 100000 \
    --save-dir $CHECKPOINT_DIR
# Note: the above command assumes 8 GPUs. Adjust `--update-freq` if you have a
# different number of GPUs.
```

Average the last 10 checkpoints:
```bash
python scripts/average_checkpoints.py \
    --inputs $CHECKPOINT_DIR \
    --num-epoch-checkpoints 10 \
    --output $CHECKPOINT_DIR/checkpoint.avg10.pt
```

Evaluate BLEU:
```bash
# tokenized BLEU on newstest2017:
bash examples/backtranslation/tokenized_bleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU4 = 32.35, 64.4/38.9/26.2/18.3 (BP=0.977, ratio=0.977, syslen=60729, reflen=62152)
# compare to 32.35 in Table 1, which is also tokenized BLEU

# generally it's better to report (detokenized) sacrebleu:
bash examples/backtranslation/sacrebleu.sh \
    wmt17 \
    en-de \
    data-bin/wmt18_en_de \
    data-bin/wmt18_en_de/code \
    $CHECKPOINT_DIR/checkpoint.avg10.pt
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt17+tok.13a+version.1.4.3 = 31.5 64.3/38.2/25.6/17.6 (BP = 0.971 ratio = 0.971 hyp_len = 59515 ref_len = 61287)
```

## Citation
```bibtex
@inproceedings{edunov2018backtranslation,
  title = {Understanding Back-Translation at Scale},
  author = {Edunov, Sergey and Ott, Myle and Auli, Michael and Grangier, David},
  booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
  year = 2018,
}
```
fairseq-0.10.2/examples/backtranslation/deduplicate_lines.py ADDED
@@ -0,0 +1,41 @@
#!/usr/bin/python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput
import hashlib
import sys
from multiprocessing import Pool


def get_hashes_and_lines(raw_line):
    hash = hashlib.md5(raw_line).hexdigest()
    return hash, raw_line


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--workers", type=int, default=10)
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    seen = set()
    with fileinput.input(args.files, mode="rb") as h:
        pool = Pool(args.workers)
        results = pool.imap_unordered(get_hashes_and_lines, h, 1000)
        for i, (hash, raw_line) in enumerate(results):
            if hash not in seen:
                seen.add(hash)
                sys.stdout.buffer.write(raw_line)
            if i % 1000000 == 0:
                print(i, file=sys.stderr, end="", flush=True)
            elif i % 100000 == 0:
                print(".", file=sys.stderr, end="", flush=True)
    print(file=sys.stderr, flush=True)


if __name__ == "__main__":
    main()
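The core idea of the script above — hash each line and keep only the first occurrence of each hash — can be exercised on a small in-memory example. A simplified, single-process sketch (no multiprocessing or streaming):

```python
import hashlib


def dedupe(lines):
    """Keep the first occurrence of each line, comparing fixed-size MD5
    digests instead of the (potentially long) lines themselves."""
    seen = set()
    out = []
    for line in lines:
        h = hashlib.md5(line.encode("utf-8")).hexdigest()
        if h not in seen:
            seen.add(h)
            out.append(line)
    return out


print(dedupe(["a", "b", "a", "c", "b"]))  # ['a', 'b', 'c']
```

Storing 32-character digests rather than full sentences keeps the `seen` set small even for the ~25M-line monolingual corpus.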
fairseq-0.10.2/examples/backtranslation/extract_bt_data.py ADDED
@@ -0,0 +1,72 @@
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import fileinput

from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser(
        description=(
            "Extract back-translations from the stdout of fairseq-generate. "
            "If there are multiple hypotheses for a source, we only keep the first one."
        )
    )
    parser.add_argument("--output", required=True, help="output prefix")
    parser.add_argument(
        "--srclang", required=True, help="source language (extracted from H-* lines)"
    )
    parser.add_argument(
        "--tgtlang", required=True, help="target language (extracted from S-* lines)"
    )
    parser.add_argument("--minlen", type=int, help="min length filter")
    parser.add_argument("--maxlen", type=int, help="max length filter")
    parser.add_argument("--ratio", type=float, help="ratio filter")
    parser.add_argument("files", nargs="*", help="input files")
    args = parser.parse_args()

    def validate(src, tgt):
        srclen = len(src.split(" ")) if src != "" else 0
        tgtlen = len(tgt.split(" ")) if tgt != "" else 0
        if (
            (args.minlen is not None and (srclen < args.minlen or tgtlen < args.minlen))
            or (
                args.maxlen is not None
                and (srclen > args.maxlen or tgtlen > args.maxlen)
            )
            or (
                args.ratio is not None
                and (max(srclen, tgtlen) / float(min(srclen, tgtlen)) > args.ratio)
            )
        ):
            return False
        return True

    def safe_index(toks, index, default):
        try:
            return toks[index]
        except IndexError:
            return default

    with open(args.output + "." + args.srclang, "w") as src_h, open(
        args.output + "." + args.tgtlang, "w"
    ) as tgt_h:
        tgt = None  # no source line seen yet
        for line in tqdm(fileinput.input(args.files)):
            if line.startswith("S-"):
                tgt = safe_index(line.rstrip().split("\t"), 1, "")
            elif line.startswith("H-"):
                if tgt is not None:
                    src = safe_index(line.rstrip().split("\t"), 2, "")
                    if validate(src, tgt):
                        print(src, file=src_h)
                        print(tgt, file=tgt_h)
                    tgt = None


if __name__ == "__main__":
    main()
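For reference, `fairseq-generate` writes tab-separated records, and the extraction above keys off the `S-` (source) and `H-` (hypothesis) line prefixes, where `H-` lines carry the score in column 2 and the tokens in column 3. A minimal sketch of that parsing, run on a made-up two-sentence output (illustrative text, not real model output):

```python
# A made-up fairseq-generate output for two German inputs (illustrative only).
sample_output = (
    "S-0\tMaschinelles Lernen ist spannend .\n"
    "H-0\t-0.4\tmachine learning is exciting .\n"
    "S-1\tHallo Welt !\n"
    "H-1\t-0.2\thello world !\n"
)

pairs = []
tgt = None
for line in sample_output.splitlines():
    cols = line.split("\t")
    if line.startswith("S-"):
        tgt = cols[1]                 # the German input becomes the BT *target*
    elif line.startswith("H-") and tgt is not None:
        pairs.append((cols[2], tgt))  # synthetic English source from column 3
        tgt = None

print(pairs[0])  # ('machine learning is exciting .', 'Maschinelles Lernen ist spannend .')
```

Note the role reversal: the model's hypothesis becomes the new *source* side of the synthetic training pair, and the original monolingual sentence becomes the *target*.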
fairseq-0.10.2/examples/backtranslation/prepare-de-monolingual.sh ADDED
@@ -0,0 +1,98 @@
#!/bin/bash

SCRIPTS=mosesdecoder/scripts
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
NORM_PUNC=$SCRIPTS/tokenizer/normalize-punctuation.perl
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
BPEROOT=subword-nmt/subword_nmt

BPE_CODE=wmt18_en_de/code
SUBSAMPLE_SIZE=25000000
LANG=de

OUTDIR=wmt18_${LANG}_mono
orig=orig
tmp=$OUTDIR/tmp
mkdir -p $OUTDIR $tmp

URLS=(
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2007.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2008.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2009.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2010.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2011.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2012.de.shuffled.gz"
    "http://www.statmt.org/wmt14/training-monolingual-news-crawl/news.2013.de.shuffled.gz"
    "http://www.statmt.org/wmt15/training-monolingual-news-crawl-v2/news.2014.de.shuffled.v2.gz"
    "http://data.statmt.org/wmt16/translation-task/news.2015.de.shuffled.gz"
    "http://data.statmt.org/wmt17/translation-task/news.2016.de.shuffled.gz"
    "http://data.statmt.org/wmt18/translation-task/news.2017.de.shuffled.deduped.gz"
)
FILES=(
    "news.2007.de.shuffled.gz"
    "news.2008.de.shuffled.gz"
    "news.2009.de.shuffled.gz"
    "news.2010.de.shuffled.gz"
    "news.2011.de.shuffled.gz"
    "news.2012.de.shuffled.gz"
    "news.2013.de.shuffled.gz"
    "news.2014.de.shuffled.v2.gz"
    "news.2015.de.shuffled.gz"
    "news.2016.de.shuffled.gz"
    "news.2017.de.shuffled.deduped.gz"
)

cd $orig
for ((i=0;i<${#URLS[@]};++i)); do
    file=${FILES[i]}
    if [ -f $file ]; then
        echo "$file already exists, skipping download"
    else
        url=${URLS[i]}
        wget "$url"
    fi
done
cd ..

if [ -f $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found monolingual sample, skipping shuffle/sample/tokenize"
else
    gzip -c -d -k $(for FILE in "${FILES[@]}"; do echo $orig/$FILE; done) \
        | shuf -n $SUBSAMPLE_SIZE \
        | perl $NORM_PUNC $LANG \
        | perl $REM_NON_PRINT_CHAR \
        | perl $TOKENIZER -threads 8 -a -l $LANG \
        > $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG}
fi

if [ -f $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found BPE monolingual sample, skipping BPE step"
else
    python $BPEROOT/apply_bpe.py -c $BPE_CODE \
        < $tmp/monolingual.${SUBSAMPLE_SIZE}.${LANG} \
        > $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG}
fi

if [ -f $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} ]; then
    echo "found deduplicated monolingual sample, skipping deduplication step"
else
    python deduplicate_lines.py $tmp/bpe.monolingual.${SUBSAMPLE_SIZE}.${LANG} \
        > $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG}
fi

if [ -f $OUTDIR/bpe.monolingual.dedup.00.de ]; then
    echo "found sharded data, skipping sharding step"
else
    split --lines 1000000 --numeric-suffixes \
        --additional-suffix .${LANG} \
        $tmp/bpe.monolingual.dedup.${SUBSAMPLE_SIZE}.${LANG} \
        $OUTDIR/bpe.monolingual.dedup.
fi
fairseq-0.10.2/examples/backtranslation/sacrebleu.sh ADDED
@@ -0,0 +1,37 @@
#!/bin/bash

if [ $# -ne 5 ]; then
    echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
    exit
fi

DATASET=$1
LANGPAIR=$2
DATABIN=$3
BPECODE=$4
MODEL=$5

SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)

BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
    BPEROOT=subword-nmt/subword_nmt
    if [ ! -e $BPEROOT ]; then
        echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
        git clone https://github.com/rsennrich/subword-nmt.git
    fi
fi

sacrebleu -t $DATASET -l $LANGPAIR --echo src \
| sacremoses tokenize -a -l $SRCLANG -q \
| python $BPEROOT/apply_bpe.py -c $BPECODE \
| fairseq-interactive $DATABIN --path $MODEL \
    -s $SRCLANG -t $TGTLANG \
    --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
| grep ^H- | cut -f 3- \
| sacremoses detokenize -l $TGTLANG -q \
| sacrebleu -t $DATASET -l $LANGPAIR
fairseq-0.10.2/examples/backtranslation/tokenized_bleu.sh ADDED
@@ -0,0 +1,46 @@
#!/bin/bash

if [ $# -ne 5 ]; then
    echo "usage: $0 [dataset=wmt14/full] [langpair=en-de] [databin] [bpecode] [model]"
    exit
fi

DATASET=$1
LANGPAIR=$2
DATABIN=$3
BPECODE=$4
MODEL=$5

SRCLANG=$(echo $LANGPAIR | cut -d '-' -f 1)
TGTLANG=$(echo $LANGPAIR | cut -d '-' -f 2)

BPEROOT=examples/backtranslation/subword-nmt/subword_nmt
if [ ! -e $BPEROOT ]; then
    BPEROOT=subword-nmt/subword_nmt
    if [ ! -e $BPEROOT ]; then
        echo 'Cloning Subword NMT repository (for BPE pre-processing)...'
        git clone https://github.com/rsennrich/subword-nmt.git
    fi
fi

TMP_REF=$(mktemp)

sacrebleu -t $DATASET -l $LANGPAIR --echo ref -q \
| sacremoses normalize -l $TGTLANG -q \
| sacremoses tokenize -a -l $TGTLANG -q \
> $TMP_REF

sacrebleu -t $DATASET -l $LANGPAIR --echo src -q \
| sacremoses normalize -l $SRCLANG -q \
| sacremoses tokenize -a -l $SRCLANG -q \
| python $BPEROOT/apply_bpe.py -c $BPECODE \
| fairseq-interactive $DATABIN --path $MODEL \
    -s $SRCLANG -t $TGTLANG \
    --beam 5 --remove-bpe --buffer-size 1024 --max-tokens 8000 \
| grep ^H- | cut -f 3- \
| fairseq-score --ref $TMP_REF

rm -f $TMP_REF
fairseq-0.10.2/examples/bart/README.glue.md ADDED
@@ -0,0 +1,99 @@
# Fine-tuning BART on GLUE tasks

### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
```bash
wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
python download_glue_data.py --data_dir glue_data --tasks all
```

### 2) Preprocess GLUE task data (same as RoBERTa):
```bash
./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
```
`glue_task_name` is one of the following:
`{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
Use `ALL` to preprocess all the GLUE tasks.

### 3) Fine-tuning on a GLUE task:
Example fine-tuning command for the `RTE` task:
```bash
TOTAL_NUM_UPDATES=2036  # 10 epochs through RTE for bsz 16
WARMUP_UPDATES=61       # 6 percent of the number of updates
LR=1e-05                # Peak LR for polynomial LR scheduler.
NUM_CLASSES=2
MAX_SENTENCES=16        # Batch size.
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1 fairseq-train RTE-bin/ \
    --restore-file $BART_PATH \
    --batch-size $MAX_SENTENCES \
    --max-tokens 4400 \
    --task sentence_prediction \
    --add-prev-output-tokens \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --init-token 0 \
    --arch bart_large \
    --criterion sentence_prediction \
    --num-classes $NUM_CLASSES \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-08 \
    --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --max-epoch 10 \
    --find-unused-parameters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
```

For each GLUE task, you will need the following command-line arguments:

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
`--lr` | 5e-6 | 1e-5 | 1e-5 | 1e-5 | 5e-6 | 2e-5 | 2e-5 | 2e-5
`bsz` | 128 | 32 | 32 | 32 | 128 | 64 | 64 | 32
`--total-num-update` | 30968 | 33112 | 113272 | 1018 | 5233 | 1148 | 1334 | 1799
`--warmup-updates` | 1858 | 1986 | 6796 | 61 | 314 | 68 | 80 | 107

For `STS-B`, additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.

**Note:**

a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=32/64/128` depending on the task.

b) The above arguments and hyperparameters were tested on an Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.
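The `--warmup-updates` column is roughly 6 percent of `--total-num-update`, matching the comment in the fine-tuning command above. A quick sanity check of that relationship (approximate: rounding in the published table differs by a step or two for some tasks):

```python
def warmup_updates(total_num_update, warmup_frac=0.06):
    # Warm up for ~6% of training, per the "6 percent" comment above.
    return int(warmup_frac * total_num_update)

print(warmup_updates(1018))  # 61, the RTE value from the table
```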
69
+
70
+ ### Inference on GLUE task
71
+ After training the model as mentioned in previous step, you can perform inference with checkpoints in `checkpoints/` directory using following python code snippet:
72
+
73
+ ```python
74
+ from fairseq.models.bart import BARTModel
75
+
76
+ bart = BARTModel.from_pretrained(
77
+ 'checkpoints/',
78
+ checkpoint_file='checkpoint_best.pt',
79
+ data_name_or_path='RTE-bin'
80
+ )
81
+
82
+ label_fn = lambda label: bart.task.label_dictionary.string(
83
+ [label + bart.task.label_dictionary.nspecial]
84
+ )
85
+ ncorrect, nsamples = 0, 0
86
+ bart.cuda()
87
+ bart.eval()
88
+ with open('glue_data/RTE/dev.tsv') as fin:
89
+ fin.readline()
90
+ for index, line in enumerate(fin):
91
+ tokens = line.strip().split('\t')
92
+ sent1, sent2, target = tokens[1], tokens[2], tokens[3]
93
+ tokens = bart.encode(sent1, sent2)
94
+ prediction = bart.predict('sentence_classification_head', tokens).argmax().item()
95
+ prediction_label = label_fn(prediction)
96
+ ncorrect += int(prediction_label == target)
97
+ nsamples += 1
98
+ print('| Accuracy: ', float(ncorrect)/float(nsamples))
99
+ ```
fairseq-0.10.2/examples/bart/README.md ADDED
# BART: Denoising Sequence-to-Sequence Pre-training for Natural Language Generation, Translation, and Comprehension

[https://arxiv.org/pdf/1910.13461.pdf]

## Introduction

BART is a sequence-to-sequence model trained with a denoising pretraining objective. This objective is more general than masked language modeling: BART matches [RoBERTa](../roberta) results on SQuAD and GLUE and achieves state-of-the-art results on summarization (XSum, CNN/DailyMail), long-form generative question answering (ELI5) and dialogue response generation (ConvAI2). See the associated paper for more details.

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`bart.base` | BART model with 6 encoder and decoder layers | 140M | [bart.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz)
`bart.large` | BART model with 12 encoder and decoder layers | 400M | [bart.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz)
`bart.large.mnli` | `bart.large` finetuned on `MNLI` | 400M | [bart.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz)
`bart.large.cnn` | `bart.large` finetuned on `CNN-DM` | 400M | [bart.large.cnn.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz)
`bart.large.xsum` | `bart.large` finetuned on `Xsum` | 400M | [bart.large.xsum.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz)

## Results

**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`bart.large` | 89.9 | 94.9 | 92.5 | 87.0 | 96.6 | 90.4 | 62.8 | 91.2

**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4
`bart.large` | 88.8/94.6 | 86.1/89.2

**[CNN/Daily Mail](http://nlpprogress.com/english/summarization.html)**
_(test set, no additional data used)_

Model | R1 | R2 | RL
---|---|---|---
`BERTSUMEXTABS` | 42.13 | 19.60 | 39.18
`bart.large` | 44.16 | 21.28 | 40.90

## Example usage

##### Load BART from torch.hub (PyTorch >= 1.1):
```python
import torch
bart = torch.hub.load('pytorch/fairseq', 'bart.large')
bart.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Load BART (for PyTorch 1.0 or custom models):
```bash
# Download the bart.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz
tar -xzvf bart.large.tar.gz
```
```python
# Load the model in fairseq
from fairseq.models.bart import BARTModel
bart = BARTModel.from_pretrained('/path/to/bart.large', checkpoint_file='model.pt')
bart.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = bart.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
bart.decode(tokens)  # 'Hello world!'
```

##### Extract features from BART:
```python
# Extract the last layer's features
last_layer_features = bart.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layers' features from the decoder (layer 0 is the embedding layer)
all_layers = bart.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 13
assert torch.all(all_layers[-1] == last_layer_features)
```

##### Use BART for sentence-pair classification tasks:
```python
# Download BART already finetuned for MNLI
bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()  # disable dropout for evaluation

# Encode a pair of sentences and make a prediction
tokens = bart.encode('BART is a seq2seq model.', 'BART is not sequence to sequence.')
bart.predict('mnli', tokens).argmax()  # 0: contradiction

# Encode another pair of sentences
tokens = bart.encode('BART is denoising autoencoder.', 'BART is version of autoencoder.')
bart.predict('mnli', tokens).argmax()  # 2: entailment
```

##### Register a new (randomly initialized) classification head:
```python
bart.register_classification_head('new_task', num_classes=3)
logprobs = bart.predict('new_task', tokens)
```

##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens

bart = torch.hub.load('pytorch/fairseq', 'bart.large.mnli')
bart.eval()

batch_of_pairs = [
    ['BART is a seq2seq model.', 'BART is not sequence to sequence.'],
    ['BART is denoising autoencoder.', 'BART is version of autoencoder.'],
]

batch = collate_tokens(
    [bart.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

logprobs = bart.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2])
```

##### Using the GPU:
```python
bart.cuda()
bart.predict('new_task', tokens)
```

#### Evaluating the `bart.large.mnli` model:

Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
bart.cuda()
bart.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
        tokens = bart.encode(sent1, sent2)
        prediction = bart.predict('mnli', tokens).argmax().item()
        prediction_label = label_map[prediction]
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9010
```

#### Evaluating the `bart.large.cnn` model:
Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download and process the data into files such that `test.source` and `test.target` have one line for each non-tokenized sample.
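Before running generation it is worth verifying that the two files are actually parallel (same number of lines). A small self-contained sketch; the demo writes throwaway files instead of the real `test.source`/`test.target`:

```python
import os
import tempfile

def is_parallel(src_path, tgt_path):
    """True if both files contain the same number of lines."""
    with open(src_path) as s, open(tgt_path) as t:
        return sum(1 for _ in s) == sum(1 for _ in t)

# demo on throwaway files standing in for test.source / test.target
d = tempfile.mkdtemp()
src = os.path.join(d, 'test.source')
tgt = os.path.join(d, 'test.target')
with open(src, 'w') as f:
    f.write('article one\narticle two\n')
with open(tgt, 'w') as f:
    f.write('summary one\nsummary two\n')
assert is_parallel(src, tgt)
```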
```python
bart = torch.hub.load('pytorch/fairseq', 'bart.large.cnn')
bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('test.source') as source, open('test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)

            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1
    if slines != []:
        hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
```

Install `files2rouge` from [here](https://github.com/pltrdy/files2rouge).

```bash
export CLASSPATH=/path/to/stanford-corenlp-full-2016-10-31/stanford-corenlp-3.7.0.jar

# Tokenize the hypothesis and target files.
cat test.hypo | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.tokenized
cat test.target | java edu.stanford.nlp.process.PTBTokenizer -ioFileList -preserveLines > test.hypo.target
files2rouge test.hypo.tokenized test.hypo.target
# Expected output: (ROUGE-2 Average_F: 0.21238)
```


## Finetuning

- [Finetuning on GLUE](README.glue.md)
- [Finetuning on CNN-DM](README.summarization.md)

## Citation

```bibtex
@article{lewis2019bart,
  title = {BART: Denoising Sequence-to-Sequence Pre-training for Natural
           Language Generation, Translation, and Comprehension},
  author = {Mike Lewis and Yinhan Liu and Naman Goyal and Marjan Ghazvininejad and
            Abdelrahman Mohamed and Omer Levy and Veselin Stoyanov
            and Luke Zettlemoyer},
  journal = {arXiv preprint arXiv:1910.13461},
  year = {2019},
}
```
fairseq-0.10.2/examples/bart/README.summarization.md ADDED
# Fine-tuning BART on the CNN-DailyMail summarization task

### 1) Download the CNN and Daily Mail data and preprocess it into data files with non-tokenized cased samples.

Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download the original CNN and Daily Mail datasets. To preprocess the data, refer to the pointers in [this issue](https://github.com/pytorch/fairseq/issues/1391) or check out the code [here](https://github.com/artmatsak/cnn-dailymail).

Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to download the original Extreme Summarization (XSum) dataset, or check out the code [here](https://github.com/EdinburghNLP/XSum/tree/master/XSum-Dataset). Keep the dataset raw: apply neither tokenization nor BPE to it.

### 2) BPE preprocess:

```bash
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

TASK=cnn_dm
for SPLIT in train val
do
  for LANG in source target
  do
    python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json \
      --vocab-bpe vocab.bpe \
      --inputs "$TASK/$SPLIT.$LANG" \
      --outputs "$TASK/$SPLIT.bpe.$LANG" \
      --workers 60 \
      --keep-empty;
  done
done
```

### 3) Binarize dataset:
```bash
fairseq-preprocess \
  --source-lang "source" \
  --target-lang "target" \
  --trainpref "${TASK}/train.bpe" \
  --validpref "${TASK}/val.bpe" \
  --destdir "${TASK}-bin/" \
  --workers 60 \
  --srcdict dict.txt \
  --tgtdict dict.txt;
```

### 4) Fine-tuning on CNN-DM summarization task:
Example of fine-tuning on CNN-DM:
```bash
TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=4
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train cnn_dm-bin \
    --restore-file $BART_PATH \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
```
The above is expected to run on `1` node with `8 32gb-V100` GPUs.
Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`.

For the XSum task, use `TOTAL_NUM_UPDATES=15000` and `UPDATE_FREQ=2`.

### Inference for CNN-DM test data using the above trained checkpoint.
After training the model as described in the previous step, you can perform inference with checkpoints in the `checkpoints/` directory using the following Python code snippet:

```python
import torch
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='cnn_dm-bin'
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)

            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
                fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1
    if slines != []:
        hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
            fout.flush()
```
For XSum generation, use `beam=6, lenpen=1.0, max_len_b=60, min_len=10`.
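The generation loop above accumulates `bsz` lines at a time before calling `bart.sample`. The grouping logic can be factored into a small generic helper; this is a sketch only, with `summarize` standing in for a hypothetical batched call:

```python
def chunks(iterable, size):
    """Yield successive lists of up to `size` items from `iterable`."""
    batch = []
    for item in iterable:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []
    if batch:  # final partial batch, like the `if slines != []` tail above
        yield batch

# with a hypothetical `summarize(lines)` callable, the loop reduces to:
#   for batch in chunks(source, bsz):
#       fout.writelines(h + '\n' for h in summarize(batch))
assert list(chunks(range(5), 2)) == [[0, 1], [2, 3], [4]]
```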
fairseq-0.10.2/examples/criss/download_and_preprocess_flores_test.sh ADDED
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

SPM_ENCODE=flores/scripts/spm_encode.py
DATA=data_tmp
SPM_MODEL=criss_checkpoints/sentence.bpe.model
DICT=criss_checkpoints/dict.txt

download_data() {
  CORPORA=$1
  URL=$2

  if [ -f $CORPORA ]; then
    echo "$CORPORA already exists, skipping download"
  else
    echo "Downloading $URL"
    wget $URL -O $CORPORA --no-check-certificate || rm -f $CORPORA
    if [ -f $CORPORA ]; then
      echo "$URL successfully downloaded."
    else
      echo "$URL not successfully downloaded."
      rm -f $CORPORA
    fi
  fi
}

# `flores` is a directory created by git clone, so test with -d (not -f)
if [[ -d flores ]]; then
  echo "flores already cloned"
else
  git clone https://github.com/facebookresearch/flores
fi

mkdir -p $DATA
download_data $DATA/wikipedia_en_ne_si_test_sets.tgz "https://github.com/facebookresearch/flores/raw/master/data/wikipedia_en_ne_si_test_sets.tgz"
pushd $DATA
pwd
tar -vxf wikipedia_en_ne_si_test_sets.tgz
popd


for lang in ne_NP si_LK; do
  datadir=$DATA/${lang}-en_XX-flores
  rm -rf $datadir
  mkdir -p $datadir
  TEST_PREFIX=$DATA/wikipedia_en_ne_si_test_sets/wikipedia.test
  python $SPM_ENCODE \
    --model ${SPM_MODEL} \
    --output_format=piece \
    --inputs ${TEST_PREFIX}.${lang:0:2}-en.${lang:0:2} ${TEST_PREFIX}.${lang:0:2}-en.en \
    --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX

  # binarize data
  fairseq-preprocess \
    --source-lang ${lang} --target-lang en_XX \
    --testpref $datadir/test.bpe.${lang}-en_XX \
    --destdir $datadir \
    --srcdict ${DICT} \
    --joined-dictionary \
    --workers 4
done
fairseq-0.10.2/examples/criss/download_and_preprocess_tatoeba.sh ADDED
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

SPM_ENCODE=flores/scripts/spm_encode.py
DATA=data_tmp
SPM_MODEL=criss_checkpoints/sentence.bpe.model
DICT=criss_checkpoints/dict.txt

git clone https://github.com/facebookresearch/LASER
mkdir -p $DATA
declare -A lang_tatoeba_map=( ["ar_AR"]="ara" ["de_DE"]="deu" ["es_XX"]="spa" ["et_EE"]="est" ["fi_FI"]="fin" ["fr_XX"]="fra" ["hi_IN"]="hin" ["it_IT"]="ita" ["ja_XX"]="jpn" ["ko_KR"]="kor" ["kk_KZ"]="kaz" ["nl_XX"]="nld" ["ru_RU"]="rus" ["tr_TR"]="tur" ["vi_VN"]="vie" ["zh_CN"]="cmn")
for lang in ar_AR de_DE es_XX et_EE fi_FI fr_XX hi_IN it_IT ja_XX kk_KZ ko_KR nl_XX ru_RU tr_TR vi_VN zh_CN; do
  lang_tatoeba=${lang_tatoeba_map[$lang]}
  echo $lang_tatoeba
  datadir=$DATA/${lang}-en_XX-tatoeba
  rm -rf $datadir
  mkdir -p $datadir
  TEST_PREFIX=LASER/data/tatoeba/v1/tatoeba
  python $SPM_ENCODE \
    --model ${SPM_MODEL} \
    --output_format=piece \
    --inputs ${TEST_PREFIX}.${lang_tatoeba}-eng.${lang_tatoeba} ${TEST_PREFIX}.${lang_tatoeba}-eng.eng \
    --outputs $datadir/test.bpe.${lang}-en_XX.${lang} $datadir/test.bpe.${lang}-en_XX.en_XX

  # binarize data
  fairseq-preprocess \
    --source-lang ${lang} --target-lang en_XX \
    --testpref $datadir/test.bpe.${lang}-en_XX \
    --destdir $datadir \
    --srcdict ${DICT} \
    --joined-dictionary \
    --workers 4
done
fairseq-0.10.2/examples/criss/mining/mine.py ADDED
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import glob
from subprocess import check_call

import faiss
import numpy as np


GB = 1024 * 1024 * 1024


def call(cmd):
    print(cmd)
    check_call(cmd, shell=True)


def get_batches(directory, lang, prefix="all_avg_pool"):
    print(f"Finding in {directory}/{prefix}.{lang}*")
    files = glob.glob(f"{directory}/{prefix}.{lang}*")
    emb_files = []
    txt_files = []
    for emb_fi in files:
        emb_files.append(emb_fi)
        txt_fi = emb_fi.replace(prefix, "sentences")
        txt_files.append(txt_fi)
    return emb_files, txt_files


def load_batch(emb_file, dim):
    embeddings = np.fromfile(emb_file, dtype=np.float32)
    num_rows = int(embeddings.shape[0] / dim)
    embeddings = embeddings.reshape((num_rows, dim))
    faiss.normalize_L2(embeddings)
    return embeddings


def knnGPU_sharded(x_batches_f, y_batches_f, dim, k, direction="x2y"):
    sims = []
    inds = []
    xfrom = 0
    xto = 0
    for x_batch_f in x_batches_f:
        yfrom = 0
        yto = 0
        x_batch = load_batch(x_batch_f, dim)
        xto = xfrom + x_batch.shape[0]
        bsims, binds = [], []
        for y_batch_f in y_batches_f:
            y_batch = load_batch(y_batch_f, dim)
            neighbor_size = min(k, y_batch.shape[0])
            yto = yfrom + y_batch.shape[0]
            print("{}-{} -> {}-{}".format(xfrom, xto, yfrom, yto))
            idx = faiss.IndexFlatIP(dim)
            idx = faiss.index_cpu_to_all_gpus(idx)
            idx.add(y_batch)
            bsim, bind = idx.search(x_batch, neighbor_size)

            bsims.append(bsim)
            binds.append(bind + yfrom)
            yfrom += y_batch.shape[0]
            del idx
            del y_batch
        bsims = np.concatenate(bsims, axis=1)
        binds = np.concatenate(binds, axis=1)
        aux = np.argsort(-bsims, axis=1)
        sim_batch = np.zeros((x_batch.shape[0], k), dtype=np.float32)
        ind_batch = np.zeros((x_batch.shape[0], k), dtype=np.int64)
        for i in range(x_batch.shape[0]):
            for j in range(k):
                sim_batch[i, j] = bsims[i, aux[i, j]]
                ind_batch[i, j] = binds[i, aux[i, j]]
        sims.append(sim_batch)
        inds.append(ind_batch)
        xfrom += x_batch.shape[0]
        del x_batch
    sim = np.concatenate(sims, axis=0)
    ind = np.concatenate(inds, axis=0)
    return sim, ind


def score(sim, fwd_mean, bwd_mean, margin):
    return margin(sim, (fwd_mean + bwd_mean) / 2)


def score_candidates(
    sim_mat, candidate_inds, fwd_mean, bwd_mean, margin, verbose=False
):
    print(" - scoring {:d} candidates".format(sim_mat.shape[0]))
    scores = np.zeros(candidate_inds.shape)
    for i in range(scores.shape[0]):
        for j in range(scores.shape[1]):
            k = int(candidate_inds[i, j])
            scores[i, j] = score(sim_mat[i, j], fwd_mean[i], bwd_mean[k], margin)
    return scores


def load_text(files):
    all_sentences = []
    for fi in files:
        with open(fi) as sentence_fi:
            for line in sentence_fi:
                all_sentences.append(line.strip())
    print(f"Read {len(all_sentences)} sentences")
    return all_sentences


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Mine bitext")
    parser.add_argument("--src-lang", help="Source language")
    parser.add_argument("--tgt-lang", help="Target language")
    parser.add_argument(
        "--dict-path", help="Path to dictionary file", default="dict.txt"
    )
    parser.add_argument(
        "--spm-path", help="Path to SPM model file", default="sentence.bpe.model"
    )
    parser.add_argument("--dim", type=int, default=1024, help="Embedding dimension")
    parser.add_argument("--mem", type=int, default=5, help="Memory in GB")
    parser.add_argument("--src-dir", help="Source directory")
    parser.add_argument("--tgt-dir", help="Target directory")
    parser.add_argument("--output", help="Output path")
    parser.add_argument(
        "--neighborhood", type=int, default=4, help="Neighborhood size for kNN search"
    )
    parser.add_argument(
        "--threshold", type=float, default=1.06, help="Threshold on mined bitext"
    )
    parser.add_argument(
        "--valid-size",
        type=int,
        default=2000,
        help="Number of sentences used for validation set",
    )
    parser.add_argument(
        "--min-count",
        type=int,
        default=50000,
        help="Min num sentences used for each language",
    )
    args = parser.parse_args()

    x_batches_f, x_sents_f = get_batches(args.src_dir, args.src_lang)
    y_batches_f, y_sents_f = get_batches(args.tgt_dir, args.tgt_lang)
    margin = lambda a, b: a / b
    y2x_sim, y2x_ind = knnGPU_sharded(
        y_batches_f, x_batches_f, args.dim, args.neighborhood, direction="y2x"
    )
    x2y_sim, x2y_ind = knnGPU_sharded(
        x_batches_f, y_batches_f, args.dim, args.neighborhood, direction="x2y"
    )

    x2y_mean = x2y_sim.mean(axis=1)
    y2x_mean = y2x_sim.mean(axis=1)
    fwd_scores = score_candidates(x2y_sim, x2y_ind, x2y_mean, y2x_mean, margin)
    bwd_scores = score_candidates(y2x_sim, y2x_ind, y2x_mean, x2y_mean, margin)
    fwd_best = x2y_ind[np.arange(x2y_sim.shape[0]), fwd_scores.argmax(axis=1)]
    bwd_best = y2x_ind[np.arange(y2x_sim.shape[0]), bwd_scores.argmax(axis=1)]
    indices = np.stack(
        (
            np.concatenate((np.arange(x2y_ind.shape[0]), bwd_best)),
            np.concatenate((fwd_best, np.arange(y2x_ind.shape[0]))),
        ),
        axis=1,
    )
    scores = np.concatenate((fwd_scores.max(axis=1), bwd_scores.max(axis=1)))

    x_sentences = load_text(x_sents_f)
    y_sentences = load_text(y_sents_f)

    threshold = args.threshold
    min_count = args.min_count
    seen_src, seen_trg = set(), set()
    directory = args.output
    call(f"mkdir -p {directory}")
    src_out = open(
        f"{directory}/all.{args.src_lang}",
        mode="w",
        encoding="utf-8",
        errors="surrogateescape",
    )
    tgt_out = open(
        f"{directory}/all.{args.tgt_lang}",
        mode="w",
        encoding="utf-8",
        errors="surrogateescape",
    )
    scores_out = open(
        f"{directory}/all.scores", mode="w", encoding="utf-8", errors="surrogateescape"
    )
    count = 0
    for i in np.argsort(-scores):
        src_ind, trg_ind = indices[i]
        if src_ind not in seen_src and trg_ind not in seen_trg:
            seen_src.add(src_ind)
            seen_trg.add(trg_ind)
            if scores[i] > threshold or count < min_count:
                if x_sentences[src_ind]:
                    print(scores[i], file=scores_out)
                    print(x_sentences[src_ind], file=src_out)
                    print(y_sentences[trg_ind], file=tgt_out)
                    count += 1
                else:
                    print(f"Ignoring sentence: {x_sentences[src_ind]}")
    src_out.close()
    tgt_out.close()
    scores_out.close()

    print(f"Found {count} pairs for threshold={threshold}")
    with open(f"{directory}/all.{args.src_lang}") as all_s, open(
        f"{directory}/all.{args.tgt_lang}"
    ) as all_t, open(f"{directory}/valid.{args.src_lang}", "w") as valid_s, open(
        f"{directory}/valid.{args.tgt_lang}", "w"
    ) as valid_t, open(
        f"{directory}/train.{args.src_lang}", "w"
    ) as train_s, open(
        f"{directory}/train.{args.tgt_lang}", "w"
    ) as train_t:
        count = 0
        for s_line, t_line in zip(all_s, all_t):
            s_line = s_line.split("\t")[1]
            t_line = t_line.split("\t")[1]
            if count >= args.valid_size:
                train_s.write(s_line)
                train_t.write(t_line)
            else:
                valid_s.write(s_line)
                valid_t.write(t_line)
            count += 1
fairseq-0.10.2/examples/criss/mining/mine_example.sh ADDED
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
source_lang=kk_KZ
target_lang=en_XX
MODEL=criss_checkpoints/criss.2nd.pt
SPM=criss_checkpoints/sentence.bpe.model
SPLIT=test
LANG_DICT=criss_checkpoints/lang_dict.txt
SPM_ENCODE=flores/scripts/spm_encode.py
SAVE_ENCODER=save_encoder.py
ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL
DICT=criss_checkpoints/dict.txt
THRESHOLD=1.02
MIN_COUNT=500

DATA_DIR=data_tmp
SAVE_DIR=mining/${source_lang}_${target_lang}_mined
ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba

mkdir -p $ENCODER_SAVE_DIR/${target_lang}
mkdir -p $ENCODER_SAVE_DIR/${source_lang}
mkdir -p $SAVE_DIR

## Save encoder outputs

# Save encoder outputs for source sentences
python $SAVE_ENCODER \
  ${INPUT_DIR} \
  --path ${MODEL} \
  --task translation_multi_simple_epoch \
  --lang-pairs ${source_lang}-${target_lang} \
  --lang-dict ${LANG_DICT} \
  --gen-subset ${SPLIT} \
  --bpe 'sentencepiece' \
  -s ${source_lang} -t ${target_lang} \
  --sentencepiece-model ${SPM} \
  --remove-bpe 'sentencepiece' \
  --beam 1 \
  --lang-tok-style mbart \
  --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}

# Save encoder outputs for target sentences
python $SAVE_ENCODER \
  ${INPUT_DIR} \
  --path ${MODEL} \
  --lang-pairs ${source_lang}-${target_lang} \
  --lang-dict ${LANG_DICT} \
  --task translation_multi_simple_epoch \
  --gen-subset ${SPLIT} \
  --bpe 'sentencepiece' \
  -t ${source_lang} -s ${target_lang} \
  --sentencepiece-model ${SPM} \
  --remove-bpe 'sentencepiece' \
  --beam 1 \
  --lang-tok-style mbart \
  --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}

## Mining
python mining/mine.py \
  --src-lang ${source_lang} \
  --tgt-lang ${target_lang} \
  --dim 1024 \
  --mem 10 \
  --neighborhood 4 \
  --src-dir ${ENCODER_SAVE_DIR}/${source_lang} \
  --tgt-dir ${ENCODER_SAVE_DIR}/${target_lang} \
  --output $SAVE_DIR \
  --threshold ${THRESHOLD} \
  --min-count ${MIN_COUNT} \
  --valid-size 100 \
  --dict-path ${DICT} \
  --spm-path ${SPM}


## Process and binarize mined data
python $SPM_ENCODE \
  --model ${SPM} \
  --output_format=piece \
  --inputs mining/${source_lang}_${target_lang}_mined/train.${source_lang} mining/${source_lang}_${target_lang}_mined/train.${target_lang} \
  --outputs mining/${source_lang}_${target_lang}_mined/train.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/train.bpe.${target_lang}

python $SPM_ENCODE \
  --model ${SPM} \
  --output_format=piece \
  --inputs mining/${source_lang}_${target_lang}_mined/valid.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.${target_lang} \
  --outputs mining/${source_lang}_${target_lang}_mined/valid.bpe.${source_lang} mining/${source_lang}_${target_lang}_mined/valid.bpe.${target_lang}


fairseq-preprocess \
  --source-lang ${source_lang} \
  --target-lang ${target_lang} \
  --trainpref mining/${source_lang}_${target_lang}_mined/train.bpe \
  --validpref mining/${source_lang}_${target_lang}_mined/valid.bpe \
  --destdir mining/${source_lang}_${target_lang}_mined \
  --srcdict ${DICT} \
  --joined-dictionary \
  --workers 8
fairseq-0.10.2/examples/criss/save_encoder.py ADDED
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Translate pre-processed data with a trained model.
"""

import numpy as np
import torch
from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.sequence_generator import EnsembleModel


def get_avg_pool(
    models, sample, prefix_tokens, src_dict, remove_bpe, has_langtok=False
):
    model = EnsembleModel(models)

    # model.forward normally channels prev_output_tokens into the decoder
    # separately, but SequenceGenerator directly calls model.encoder
    encoder_input = {
        k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
    }

    # compute the encoder output for each beam
    encoder_outs = model.forward_encoder(encoder_input)
    np_encoder_outs = encoder_outs[0].encoder_out.cpu().numpy().astype(np.float32)
    encoder_mask = 1 - encoder_outs[0].encoder_padding_mask.cpu().numpy().astype(
        np.float32
    )
    encoder_mask = np.expand_dims(encoder_mask.T, axis=2)
    if has_langtok:
        encoder_mask = encoder_mask[1:, :, :]
        np_encoder_outs = np_encoder_outs[1, :, :]
    masked_encoder_outs = encoder_mask * np_encoder_outs
    avg_pool = (masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)
    return avg_pool


def main(args):
    assert args.path is not None, "--path required for generation!"
    assert (
        not args.sampling or args.nbest == args.beam
    ), "--sampling requires --nbest to be equal to --beam"
    assert (
        args.replace_unk is None or args.raw_text
    ), "--replace-unk requires a raw text dataset (--raw-text)"

    args.beam = 1
    utils.import_user_module(args)

    if args.max_tokens is None:
        args.max_tokens = 12000
    print(args)
    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)

    # Set dictionaries
    try:
        src_dict = getattr(task, "source_dictionary", None)
    except NotImplementedError:
        src_dict = None
    tgt_dict = task.target_dictionary

    # Load ensemble
    print("| loading model(s) from {}".format(args.path))
    models, _model_args = checkpoint_utils.load_model_ensemble(
        args.path.split(":"),
        arg_overrides=eval(args.model_overrides),
        task=task,
    )

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
            need_attn=args.print_alignment,
        )
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens,
        max_positions=utils.resolve_max_positions(
            task.max_positions(),
        ),
        ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
101
+ required_batch_size_multiple=args.required_batch_size_multiple,
102
+ num_shards=args.num_shards,
103
+ shard_id=args.shard_id,
104
+ num_workers=args.num_workers,
105
+ ).next_epoch_itr(shuffle=False)
106
+
107
+ num_sentences = 0
108
+ source_sentences = []
109
+ shard_id = 0
110
+ all_avg_pool = None
111
+ encoder_has_langtok = (
112
+ hasattr(task.args, "encoder_langtok")
113
+ and task.args.encoder_langtok is not None
114
+ and hasattr(task.args, "lang_tok_replacing_bos_eos")
115
+ and not task.args.lang_tok_replacing_bos_eos
116
+ )
117
+ with progress_bar.build_progress_bar(args, itr) as t:
118
+ for sample in t:
119
+ if sample is None:
120
+ print("Skipping None")
121
+ continue
122
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
123
+ if "net_input" not in sample:
124
+ continue
125
+
126
+ prefix_tokens = None
127
+ if args.prefix_size > 0:
128
+ prefix_tokens = sample["target"][:, : args.prefix_size]
129
+
130
+ with torch.no_grad():
131
+ avg_pool = get_avg_pool(
132
+ models,
133
+ sample,
134
+ prefix_tokens,
135
+ src_dict,
136
+ args.remove_bpe,
137
+ has_langtok=encoder_has_langtok,
138
+ )
139
+ if all_avg_pool is not None:
140
+ all_avg_pool = np.concatenate((all_avg_pool, avg_pool))
141
+ else:
142
+ all_avg_pool = avg_pool
143
+
144
+ if not isinstance(sample["id"], list):
145
+ sample_ids = sample["id"].tolist()
146
+ else:
147
+ sample_ids = sample["id"]
148
+ for i, sample_id in enumerate(sample_ids):
149
+ # Remove padding
150
+ src_tokens = utils.strip_pad(
151
+ sample["net_input"]["src_tokens"][i, :], tgt_dict.pad()
152
+ )
153
+
154
+ # Either retrieve the original sentences or regenerate them from tokens.
155
+ if align_dict is not None:
156
+ src_str = task.dataset(args.gen_subset).src.get_original_text(
157
+ sample_id
158
+ )
159
+ else:
160
+ if src_dict is not None:
161
+ src_str = src_dict.string(src_tokens, args.remove_bpe)
162
+ else:
163
+ src_str = ""
164
+
165
+ if not args.quiet:
166
+ if src_dict is not None:
167
+ print("S-{}\t{}".format(sample_id, src_str))
168
+
169
+ source_sentences.append(f"{sample_id}\t{src_str}")
170
+
171
+ num_sentences += sample["nsentences"]
172
+ if all_avg_pool.shape[0] >= 1000000:
173
+ with open(
174
+ f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}",
175
+ "w",
176
+ ) as avg_pool_file:
177
+ all_avg_pool.tofile(avg_pool_file)
178
+ with open(
179
+ f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}",
180
+ "w",
181
+ ) as sentence_file:
182
+ sentence_file.writelines(f"{line}\n" for line in source_sentences)
183
+ all_avg_pool = None
184
+ source_sentences = []
185
+ shard_id += 1
186
+
187
+ if all_avg_pool is not None:
188
+ with open(
189
+ f"{args.encoder_save_dir}/all_avg_pool.{args.source_lang}.{shard_id}", "w"
190
+ ) as avg_pool_file:
191
+ all_avg_pool.tofile(avg_pool_file)
192
+ with open(
193
+ f"{args.encoder_save_dir}/sentences.{args.source_lang}.{shard_id}", "w"
194
+ ) as sentence_file:
195
+ sentence_file.writelines(f"{line}\n" for line in source_sentences)
196
+ return None
197
+
198
+
199
+ def cli_main():
200
+ parser = options.get_generation_parser()
201
+ parser.add_argument(
202
+ "--encoder-save-dir",
203
+ default="",
204
+ type=str,
205
+ metavar="N",
206
+ help="directory to save encoder outputs",
207
+ )
208
+ args = options.parse_args_and_arch(parser)
209
+ main(args)
210
+
211
+
212
+ if __name__ == "__main__":
213
+ cli_main()
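The core of `get_avg_pool` is a masked average pooling that reduces the encoder's `(T, B, C)` output to one vector per sentence while ignoring padded positions. A minimal NumPy sketch of just that reduction (toy shapes, hypothetical function name, no fairseq dependency):

```python
import numpy as np

def masked_avg_pool(encoder_out, padding_mask):
    """encoder_out: (T, B, C); padding_mask: (B, T), True at pad positions."""
    mask = 1.0 - padding_mask.T.astype(np.float32)  # (T, B), 1 = real token
    mask = np.expand_dims(mask, axis=2)             # (T, B, 1)
    summed = (encoder_out * mask).sum(axis=0)       # (B, C) sum over time
    return summed / mask.sum(axis=0)                # divide by true lengths

# toy example: T=3, B=1, C=2; the last position is padding
out = np.array([[[1.0, 2.0]], [[3.0, 4.0]], [[9.0, 9.0]]])
pad = np.array([[False, False, True]])
print(masked_avg_pool(out, pad))  # [[2. 3.]]
```

This matches the script's `(masked_encoder_outs / encoder_mask.sum(axis=0)).sum(axis=0)` up to the (equivalent) order of the division and sum.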
fairseq-0.10.2/examples/criss/sentence_retrieval/encoder_analysis.py ADDED
@@ -0,0 +1,92 @@
+#!/usr/bin/env python3 -u
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import argparse
+import glob
+
+import numpy as np
+
+
+DIM = 1024
+
+
+def compute_dist(source_embs, target_embs, k=5, return_sim_mat=False):
+    target_ids = [tid for tid in target_embs]
+    source_mat = np.stack(list(source_embs.values()), axis=0)
+    normalized_source_mat = source_mat / np.linalg.norm(
+        source_mat, axis=1, keepdims=True
+    )
+    target_mat = np.stack(list(target_embs.values()), axis=0)
+    normalized_target_mat = target_mat / np.linalg.norm(
+        target_mat, axis=1, keepdims=True
+    )
+    sim_mat = normalized_source_mat.dot(normalized_target_mat.T)
+    if return_sim_mat:
+        return sim_mat
+    neighbors_map = {}
+    for i, sentence_id in enumerate(source_embs):
+        idx = np.argsort(sim_mat[i, :])[::-1][:k]
+        neighbors_map[sentence_id] = [target_ids[tid] for tid in idx]
+    return neighbors_map
+
+
+def load_embeddings(directory, LANGS):
+    sentence_embeddings = {}
+    sentence_texts = {}
+    for lang in LANGS:
+        sentence_embeddings[lang] = {}
+        sentence_texts[lang] = {}
+        lang_dir = f"{directory}/{lang}"
+        embedding_files = glob.glob(f"{lang_dir}/all_avg_pool.{lang}.*")
+        for embed_file in embedding_files:
+            shard_id = embed_file.split(".")[-1]
+            embeddings = np.fromfile(embed_file, dtype=np.float32)
+            num_rows = embeddings.shape[0] // DIM
+            embeddings = embeddings.reshape((num_rows, DIM))
+
+            with open(f"{lang_dir}/sentences.{lang}.{shard_id}") as sentence_file:
+                for idx, line in enumerate(sentence_file):
+                    sentence_id, sentence = line.strip().split("\t")
+                    sentence_texts[lang][sentence_id] = sentence
+                    sentence_embeddings[lang][sentence_id] = embeddings[idx, :]
+
+    return sentence_embeddings, sentence_texts
+
+
+def compute_accuracy(directory, LANGS):
+    sentence_embeddings, sentence_texts = load_embeddings(directory, LANGS)
+
+    top_1_accuracy = {}
+
+    top1_str = " ".join(LANGS) + "\n"
+    for source_lang in LANGS:
+        top_1_accuracy[source_lang] = {}
+        top1_str += f"{source_lang} "
+        for target_lang in LANGS:
+            top1 = 0
+            top5 = 0
+            neighbors_map = compute_dist(
+                sentence_embeddings[source_lang], sentence_embeddings[target_lang]
+            )
+            for sentence_id, neighbors in neighbors_map.items():
+                if sentence_id == neighbors[0]:
+                    top1 += 1
+                if sentence_id in neighbors[:5]:
+                    top5 += 1
+            n = len(sentence_embeddings[target_lang])
+            top1_str += f"{top1/n} "
+        top1_str += "\n"
+
+    print(top1_str)
+    with open(f"{directory}/accuracy", "w") as accuracy_file:
+        print(top1_str, file=accuracy_file)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Analyze encoder outputs")
+    parser.add_argument("directory", help="Source language corpus")
+    parser.add_argument("--langs", help="List of langs")
+    args = parser.parse_args()
+    langs = args.langs.split(",")
+    compute_accuracy(args.directory, langs)
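`compute_dist` scores every source/target pair by cosine similarity of L2-normalized embeddings and keeps the k best targets per source. The same retrieval step in isolation, with toy 2-D embeddings and hypothetical sentence ids:

```python
import numpy as np

def top_k_neighbors(source_embs, target_embs, k=2):
    """Both args: dict id -> 1-D np.ndarray. Returns id -> list of target ids."""
    tgt_ids = list(target_embs)
    S = np.stack(list(source_embs.values()))
    T = np.stack(list(target_embs.values()))
    S = S / np.linalg.norm(S, axis=1, keepdims=True)  # L2-normalize rows
    T = T / np.linalg.norm(T, axis=1, keepdims=True)
    sim = S @ T.T                                     # cosine similarity matrix
    return {
        sid: [tgt_ids[j] for j in np.argsort(sim[i])[::-1][:k]]
        for i, sid in enumerate(source_embs)
    }

src = {"s0": np.array([1.0, 0.0]), "s1": np.array([0.0, 1.0])}
tgt = {"t0": np.array([2.0, 0.1]), "t1": np.array([0.1, 3.0])}
print(top_k_neighbors(src, tgt, k=1))  # {'s0': ['t0'], 's1': ['t1']}
```

Top-1 accuracy then reduces to counting how often a sentence's own id comes back first, which is exactly what `compute_accuracy` tallies.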
fairseq-0.10.2/examples/criss/sentence_retrieval/sentence_retrieval_tatoeba.sh ADDED
@@ -0,0 +1,59 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+source_lang=kk_KZ
+target_lang=en_XX
+MODEL=criss_checkpoints/criss.3rd.pt
+SPM=criss_checkpoints/sentence.bpe.model
+SPLIT=test
+LANG_DICT=criss_checkpoints/lang_dict.txt
+ENCODER_ANALYSIS=sentence_retrieval/encoder_analysis.py
+SAVE_ENCODER=save_encoder.py
+ENCODER_SAVE_ROOT=sentence_embeddings/$MODEL
+
+
+
+DATA_DIR=data_tmp
+INPUT_DIR=$DATA_DIR/${source_lang}-${target_lang}-tatoeba
+ENCODER_SAVE_DIR=${ENCODER_SAVE_ROOT}/${source_lang}-${target_lang}
+mkdir -p $ENCODER_SAVE_DIR/${target_lang}
+mkdir -p $ENCODER_SAVE_DIR/${source_lang}
+
+# Save encoder outputs for source sentences
+python $SAVE_ENCODER \
+  ${INPUT_DIR} \
+  --path ${MODEL} \
+  --task translation_multi_simple_epoch \
+  --lang-dict ${LANG_DICT} \
+  --gen-subset ${SPLIT} \
+  --bpe 'sentencepiece' \
+  --lang-pairs ${source_lang}-${target_lang} \
+  -s ${source_lang} -t ${target_lang} \
+  --sentencepiece-model ${SPM} \
+  --remove-bpe 'sentencepiece' \
+  --beam 1 \
+  --lang-tok-style mbart \
+  --encoder-save-dir ${ENCODER_SAVE_DIR}/${source_lang}
+
+# Save encoder outputs for target sentences
+python $SAVE_ENCODER \
+  ${INPUT_DIR} \
+  --path ${MODEL} \
+  --lang-dict ${LANG_DICT} \
+  --task translation_multi_simple_epoch \
+  --gen-subset ${SPLIT} \
+  --bpe 'sentencepiece' \
+  --lang-pairs ${target_lang}-${source_lang} \
+  -t ${source_lang} -s ${target_lang} \
+  --sentencepiece-model ${SPM} \
+  --remove-bpe 'sentencepiece' \
+  --beam 1 \
+  --lang-tok-style mbart \
+  --encoder-save-dir ${ENCODER_SAVE_DIR}/${target_lang}
+
+# Analyze sentence retrieval accuracy
+python $ENCODER_ANALYSIS --langs "${source_lang},${target_lang}" ${ENCODER_SAVE_DIR}
fairseq-0.10.2/examples/criss/unsupervised_mt/eval.sh ADDED
@@ -0,0 +1,37 @@
+#!/bin/bash
+# Copyright (c) Facebook, Inc. and its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+#
+SRC=si_LK
+TGT=en_XX
+MODEL=criss_checkpoints/criss.3rd.pt
+
+MULTIBLEU=mosesdecoder/scripts/generic/multi-bleu.perl
+MOSES=mosesdecoder
+REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
+NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
+REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
+TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl
+GEN_TMP_DIR=gen_tmp
+LANG_DICT=criss_checkpoints/lang_dict.txt
+
+if [ ! -d "mosesdecoder" ]; then
+  git clone https://github.com/moses-smt/mosesdecoder
+fi
+mkdir -p $GEN_TMP_DIR
+fairseq-generate data_tmp/${SRC}-${TGT}-flores \
+  --task translation_multi_simple_epoch \
+  --max-tokens 2000 \
+  --path ${MODEL} \
+  --skip-invalid-size-inputs-valid-test \
+  --beam 5 --lenpen 1.0 --gen-subset test \
+  --remove-bpe=sentencepiece \
+  --source-lang ${SRC} --target-lang ${TGT} \
+  --decoder-langtok --lang-pairs 'en_XX-ar_AR,en_XX-de_DE,en_XX-es_XX,en_XX-fr_XX,en_XX-hi_IN,en_XX-it_IT,en_XX-ja_XX,en_XX-ko_KR,en_XX-nl_XX,en_XX-ru_RU,en_XX-zh_CN,en_XX-tr_TR,en_XX-vi_VN,en_XX-ro_RO,en_XX-my_MM,en_XX-ne_NP,en_XX-si_LK,en_XX-cs_CZ,en_XX-lt_LT,en_XX-kk_KZ,en_XX-gu_IN,en_XX-fi_FI,en_XX-et_EE,en_XX-lv_LV,ar_AR-en_XX,cs_CZ-en_XX,de_DE-en_XX,es_XX-en_XX,et_EE-en_XX,fi_FI-en_XX,fr_XX-en_XX,gu_IN-en_XX,hi_IN-en_XX,it_IT-en_XX,ja_XX-en_XX,kk_KZ-en_XX,ko_KR-en_XX,lt_LT-en_XX,lv_LV-en_XX,my_MM-en_XX,ne_NP-en_XX,nl_XX-en_XX,ro_RO-en_XX,ru_RU-en_XX,si_LK-en_XX,tr_TR-en_XX,vi_VN-en_XX,zh_CN-en_XX,ar_AR-es_XX,es_XX-ar_AR,ar_AR-hi_IN,hi_IN-ar_AR,ar_AR-zh_CN,zh_CN-ar_AR,cs_CZ-es_XX,es_XX-cs_CZ,cs_CZ-hi_IN,hi_IN-cs_CZ,cs_CZ-zh_CN,zh_CN-cs_CZ,de_DE-es_XX,es_XX-de_DE,de_DE-hi_IN,hi_IN-de_DE,de_DE-zh_CN,zh_CN-de_DE,es_XX-hi_IN,hi_IN-es_XX,es_XX-zh_CN,zh_CN-es_XX,et_EE-es_XX,es_XX-et_EE,et_EE-hi_IN,hi_IN-et_EE,et_EE-zh_CN,zh_CN-et_EE,fi_FI-es_XX,es_XX-fi_FI,fi_FI-hi_IN,hi_IN-fi_FI,fi_FI-zh_CN,zh_CN-fi_FI,fr_XX-es_XX,es_XX-fr_XX,fr_XX-hi_IN,hi_IN-fr_XX,fr_XX-zh_CN,zh_CN-fr_XX,gu_IN-es_XX,es_XX-gu_IN,gu_IN-hi_IN,hi_IN-gu_IN,gu_IN-zh_CN,zh_CN-gu_IN,hi_IN-zh_CN,zh_CN-hi_IN,it_IT-es_XX,es_XX-it_IT,it_IT-hi_IN,hi_IN-it_IT,it_IT-zh_CN,zh_CN-it_IT,ja_XX-es_XX,es_XX-ja_XX,ja_XX-hi_IN,hi_IN-ja_XX,ja_XX-zh_CN,zh_CN-ja_XX,kk_KZ-es_XX,es_XX-kk_KZ,kk_KZ-hi_IN,hi_IN-kk_KZ,kk_KZ-zh_CN,zh_CN-kk_KZ,ko_KR-es_XX,es_XX-ko_KR,ko_KR-hi_IN,hi_IN-ko_KR,ko_KR-zh_CN,zh_CN-ko_KR,lt_LT-es_XX,es_XX-lt_LT,lt_LT-hi_IN,hi_IN-lt_LT,lt_LT-zh_CN,zh_CN-lt_LT,lv_LV-es_XX,es_XX-lv_LV,lv_LV-hi_IN,hi_IN-lv_LV,lv_LV-zh_CN,zh_CN-lv_LV,my_MM-es_XX,es_XX-my_MM,my_MM-hi_IN,hi_IN-my_MM,my_MM-zh_CN,zh_CN-my_MM,ne_NP-es_XX,es_XX-ne_NP,ne_NP-hi_IN,hi_IN-ne_NP,ne_NP-zh_CN,zh_CN-ne_NP,nl_XX-es_XX,es_XX-nl_XX,nl_XX-hi_IN,hi_IN-nl_XX,nl_XX-zh_CN,zh_CN-nl_XX,ro_RO-es_XX,es_XX-ro_RO,ro_RO-hi_IN,hi_IN-ro_RO,ro_RO-zh_CN,zh_CN-ro_RO,ru_RU-es_XX,es_XX-ru_RU,ru_RU-hi_IN,hi_IN-ru_RU,ru_RU-zh_CN,zh_CN-ru_RU,si_LK-es_XX,es_XX-si_LK,si_LK-hi_IN,hi_IN-si_LK,si_LK-zh_CN,zh_CN-si_LK,tr_TR-es_XX,es_XX-tr_TR,tr_TR-hi_IN,hi_IN-tr_TR,tr_TR-zh_CN,zh_CN-tr_TR,vi_VN-es_XX,es_XX-vi_VN,vi_VN-hi_IN,hi_IN-vi_VN,vi_VN-zh_CN,zh_CN-vi_VN' \
+  --lang-dict ${LANG_DICT} --lang-tok-style 'mbart' --sampling-method 'temperature' --sampling-temperature '1.0' > $GEN_TMP_DIR/${SRC}_${TGT}.gen
+cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^T-" | cut -f2 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.ref
+cat $GEN_TMP_DIR/${SRC}_${TGT}.gen | grep -P "^H-" | cut -f3 | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l ${TGT:0:2} | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape ${TGT:0:2} > $GEN_TMP_DIR/${SRC}_${TGT}.hyp
+${MULTIBLEU} $GEN_TMP_DIR/${SRC}_${TGT}.ref < $GEN_TMP_DIR/${SRC}_${TGT}.hyp
fairseq-0.10.2/examples/cross_lingual_language_model/README.md ADDED
@@ -0,0 +1,77 @@
+# Cross-Lingual Language Model Pre-training
+
+Below are some details for training Cross-Lingual Language Models (XLM) - similar to the ones presented in [Lample & Conneau, 2019](https://arxiv.org/pdf/1901.07291.pdf) - in Fairseq. The current implementation only supports the Masked Language Model (MLM) from the paper above.
+
+## Downloading and Tokenizing Monolingual Data
+
+Pointers to the monolingual data from Wikipedia used for training the XLM-style MLM model, as well as details on processing it (tokenization and BPE), can be found in the [XLM GitHub Repository](https://github.com/facebookresearch/XLM#download--preprocess-monolingual-data).
+
+Let's assume the following for the code snippets in later sections to work:
+- Processed data is in the folder: monolingual_data/processed
+- Each language has 3 files, for train, validation and test. For example, we have the following files for English:
+  train.en, valid.en, test.en
+- We are training a model for 5 languages: Arabic (ar), German (de), English (en), Hindi (hi) and French (fr)
+- The vocabulary file is monolingual_data/processed/vocab_mlm
+
+
+## Fairseq Pre-processing and Binarization
+
+Pre-process and binarize the data with the MaskedLMDictionary and cross_lingual_lm task:
+
+```bash
+# Ensure the output directory exists
+DATA_DIR=monolingual_data/fairseq_processed
+mkdir -p "$DATA_DIR"
+
+for lg in ar de en hi fr
+do
+
+  fairseq-preprocess \
+  --task cross_lingual_lm \
+  --srcdict monolingual_data/processed/vocab_mlm \
+  --only-source \
+  --trainpref monolingual_data/processed/train \
+  --validpref monolingual_data/processed/valid \
+  --testpref monolingual_data/processed/test \
+  --destdir monolingual_data/fairseq_processed \
+  --workers 20 \
+  --source-lang $lg
+
+  # Since we only have a source language, the output file has a None for the
+  # target language. Remove this
+
+  for stage in train test valid
+  do
+    mv "$DATA_DIR/$stage.$lg-None.$lg.bin" "$DATA_DIR/$stage.$lg.bin"
+    mv "$DATA_DIR/$stage.$lg-None.$lg.idx" "$DATA_DIR/$stage.$lg.idx"
+  done
+
+done
+```
+
+## Train a Cross-lingual Language Model similar to the XLM MLM model
+
+Use the following command to train the model on 5 languages.
+
+```bash
+fairseq-train \
+--task cross_lingual_lm monolingual_data/fairseq_processed \
+--save-dir checkpoints/mlm \
+--max-update 2400000 --save-interval 1 --no-epoch-checkpoints \
+--arch xlm_base \
+--optimizer adam --lr-scheduler reduce_lr_on_plateau \
+--lr-shrink 0.5 --lr 0.0001 --min-lr 1e-09 \
+--dropout 0.1 \
+--criterion legacy_masked_lm_loss \
+--max-tokens 2048 --tokens-per-sample 256 --attention-dropout 0.1 \
+--dataset-impl lazy --seed 0 \
+--masked-lm-only \
+--monolingual-langs 'ar,de,en,hi,fr' --num-segment 5 \
+--ddp-backend=no_c10d
+```
+
+Some Notes:
+- Using tokens_per_sample greater than 256 can cause OOM (out-of-memory) issues. Usually, since MLM packs in streams of text, this parameter doesn't need much tuning.
+- The evaluation workflow for computing MLM perplexity on test data is in progress.
+- Fine-tuning this model on a downstream task is not currently supported.
fairseq-0.10.2/examples/latent_depth/latent_depth_src/loss/latent_depth.py ADDED
@@ -0,0 +1,99 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+
+import torch
+from torch.nn.modules.loss import _Loss
+
+
+class LatentLayersKLLoss(_Loss):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+
+    def forward(self, layer_samples, lang_idx, update_num, sample_size):
+        prior = self.args.prior
+        samples = layer_samples[lang_idx]
+        eps = 1e-7
+        if prior == "uniform":
+            # uniform prior
+            kl_loss = (samples * (torch.log(samples + eps) - math.log(0.5))).sum(-1)
+        elif prior == "agged_posterior":
+            # aggregated posterior
+            y_t = torch.stack([x.detach() for x in layer_samples], dim=0)
+            agged_q = torch.sum(y_t, dim=0)
+            row_norm = agged_q.sum(-1)
+            normed_agg_q = agged_q / row_norm
+            kl_loss = (
+                samples * (torch.log(samples + eps) - torch.log(normed_agg_q + eps))
+            ).sum(-1)
+        else:
+            raise NotImplementedError("The specified prior is not implemented.")
+
+        # normalized by number of layers
+        kl_loss /= layer_samples[0].size()[0]
+        kl_weight = min(
+            self.args.sparsity_weight,
+            (update_num - self.args.soft_update)
+            * self.args.sparsity_weight
+            / self.args.anneal_updates,
+        )
+        kl_loss *= kl_weight * sample_size
+        return kl_loss
+
+
+class LatentLayersSparsityLoss(_Loss):
+    def __init__(self, args):
+        super().__init__()
+        self.args = args
+
+    def is_valid(self, update_num):
+        if self.args.target_layers <= 0:
+            return False
+        return update_num > (self.args.soft_update + self.args.anneal_updates)
+
+    def forward(self, layer_samples_list, update_num, sample_size):
+        batch_loss = 0
+        share_loss = 0
+        global_sparsity_loss = 0
+        layer_samples = torch.stack(layer_samples_list, dim=0)
+        if (
+            self.args.target_layers > 0 or self.args.share_weight > 0
+        ) and update_num > (self.args.soft_update + self.args.anneal_updates):
+            # anneal sparsity weight
+            if update_num < (self.args.anneal_updates + self.args.soft_update):
+                weight_anneal = 0
+            elif update_num < (2 * self.args.anneal_updates + self.args.soft_update):
+                weight_anneal = (
+                    (update_num - self.args.soft_update - self.args.anneal_updates)
+                    * self.args.share_weight
+                    / self.args.anneal_updates
+                )
+            else:
+                weight_anneal = 1
+            # compute ratio among languages
+            layer_utilization = torch.sum(layer_samples, dim=0)
+            layer_utilization /= layer_samples.size()[0]
+            if self.args.share_weight > 0:
+                # encourage sharing across languages
+                share_loss = sum(
+                    -1.0 * v * math.log(v) for v in layer_utilization if v > 0
+                )
+                batch_loss += (
+                    weight_anneal * self.args.share_weight * sample_size * share_loss
+                )
+            if self.args.target_layers > 0:
+                # compute the expected number of selected layers
+                expected_layers = sum(layer_utilization)
+                # compute l2 loss wrt target number of layers
+                global_sparsity_loss = (expected_layers - self.args.target_layers) ** 2
+                batch_loss += (
+                    weight_anneal
+                    * self.args.share_weight
+                    * sample_size
+                    * global_sparsity_loss
+                )
+        return batch_loss
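For the `"uniform"` prior above, the per-language KL term is just `sum(q * (log q - log 0.5))` over each layer's selection probabilities. A tiny standalone check of that formula (pure-Python sketch with a hypothetical helper name, ignoring the annealing weight and `sample_size` scaling):

```python
import math

def kl_to_uniform(samples, eps=1e-7):
    """samples: list of per-layer [p_select, p_skip] pairs summing to 1."""
    return sum(
        p * (math.log(p + eps) - math.log(0.5)) for pair in samples for p in pair
    )

# A perfectly uniform posterior has (near-)zero KL to the uniform prior...
print(round(kl_to_uniform([[0.5, 0.5]]), 5))  # 0.0
# ...while a confident selection pays a positive penalty.
print(kl_to_uniform([[0.99, 0.01]]) > 0)  # True
```

This is the quantity `LatentLayersKLLoss` computes per layer before normalizing by the number of layers and applying the annealed `kl_weight`.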
fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/__init__.py ADDED
File without changes
fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/latent_multilingual_transformer.py ADDED
@@ -0,0 +1,59 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from fairseq.models import register_model, register_model_architecture
+from fairseq.models.multilingual_transformer import MultilingualTransformerModel
+from fairseq.models.transformer import (
+    TransformerDecoder,
+    TransformerEncoder,
+    base_architecture,
+)
+
+from .latent_transformer import LatentTransformerDecoder, LatentTransformerEncoder
+
+
+@register_model("latent_multilingual_transformer")
+class LatentMultilingualTransformerModel(MultilingualTransformerModel):
+    """A variant of the standard multilingual Transformer model whose encoder
+    and/or decoder supports latent depth, as in "Deep Transformer with Latent
+    Depth" (https://arxiv.org/abs/2009.13102).
+    """
+
+    @classmethod
+    def _get_module_class(cls, is_encoder, args, lang_dict, embed_tokens, langs):
+        if is_encoder:
+            if hasattr(args, "encoder_latent_layer") and args.encoder_latent_layer:
+                return LatentTransformerEncoder(
+                    args, lang_dict, embed_tokens, num_logits=len(langs)
+                )
+            else:
+                return TransformerEncoder(args, lang_dict, embed_tokens)
+        else:
+            if hasattr(args, "decoder_latent_layer") and args.decoder_latent_layer:
+                return LatentTransformerDecoder(
+                    args, lang_dict, embed_tokens, num_logits=len(langs)
+                )
+            else:
+                return TransformerDecoder(args, lang_dict, embed_tokens)
+
+
+@register_model_architecture(
+    "latent_multilingual_transformer", "latent_multilingual_transformer"
+)
+def latent_multilingual_architecture(args):
+    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 1024)
+    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 4)
+    args.encoder_layers = getattr(args, "encoder_layers", 12)
+    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 512)
+    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 1024)
+    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 4)
+    args.decoder_layers = getattr(args, "decoder_layers", 24)
+    args.share_encoders = getattr(args, "share_encoders", True)
+    args.share_decoders = getattr(args, "share_decoders", True)
+    args.share_encoder_embeddings = getattr(args, "share_encoder_embeddings", True)
+    args.share_decoder_embeddings = getattr(args, "share_decoder_embeddings", True)
+
+    base_architecture(args)
fairseq-0.10.2/examples/latent_depth/latent_depth_src/models/latent_transformer.py ADDED
@@ -0,0 +1,146 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, Dict, Optional
+
+import torch.nn as nn
+from fairseq.models.fairseq_encoder import EncoderOut
+from fairseq.models.transformer import TransformerDecoder, TransformerEncoder
+from fairseq.modules import TransformerDecoderLayer, TransformerEncoderLayer
+from torch import Tensor
+
+from ..modules.latent_layers import LayerSelect
+
+
+class LatentTransformerEncoder(TransformerEncoder):
+    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
+    TransformerEncoder.
+    """
+
+    def __init__(self, args, dictionary, embed_tokens, num_logits=1):
+        self.num_logits = num_logits
+        self.num_layers = args.encoder_layers
+        super().__init__(args, dictionary, embed_tokens)
+        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
+        self.lang_idx = None
+        self.layers = nn.ModuleList(
+            [self._build_encoder_layer(args, idx) for idx in range(args.encoder_layers)]
+        )
+
+    def set_lang_idx(self, lang_idx):
+        self.lang_idx = lang_idx
+
+    def _build_encoder_layer(self, args, idx=None):
+        return LatentTransformerEncoderLayer(args, idx, layer_select=self.layer_select)
+
+    def forward(self, src_tokens, src_lengths, return_all_hiddens: bool = False):
+        self.layer_select.sample(self.lang_idx)
+        return super().forward(src_tokens, src_lengths, return_all_hiddens)
+
+
+class LatentTransformerEncoderLayer(TransformerEncoderLayer):
+    """Encoder layer with each (non-residual) block weighted by samples from a
+    Bernoulli or Gumbel-Sigmoid distribution.
+
+    Args:
+        args (argparse.Namespace): parsed command-line arguments from standard
+            TransformerEncoderLayer.
+        idx (int): layer index (used to retrieve samples).
+        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
+            parameters and sampling method.
+    """
+
+    def __init__(self, args, idx, layer_select=None):
+        super().__init__(args)
+        self.idx = idx
+        self.layer_select = layer_select
+
+    def residual_connection(self, x, residual):
+        return residual + x * self.layer_select(self.idx)
+
+
+class LatentTransformerDecoder(TransformerDecoder):
+    """Latent depth (https://arxiv.org/abs/2009.13102) implemented in
+    TransformerDecoder.
+    """
+
+    def __init__(
+        self, args, dictionary, embed_tokens, no_encoder_attn=False, num_logits=1
+    ):
+        self.num_logits = num_logits
+        self.num_layers = args.decoder_layers
+        super().__init__(
+            args, dictionary, embed_tokens, no_encoder_attn=no_encoder_attn
+        )
+        self.layer_select = LayerSelect(self.num_layers, self.num_logits, args)
+        self.lang_idx = None
+        self.layers = nn.ModuleList(
+            [
+                self._build_decoder_layer(args, no_encoder_attn, idx)
+                for idx in range(args.decoder_layers)
+            ]
+        )
+
+    def set_lang_idx(self, lang_idx):
+        self.lang_idx = lang_idx
+
+    def _build_decoder_layer(self, args, no_encoder_attn=False, idx=None):
+        return LatentTransformerDecoderLayer(
+            args, idx, layer_select=self.layer_select, no_encoder_attn=no_encoder_attn
+        )
+
+    def forward(
+        self,
+        prev_output_tokens,
+        encoder_out: Optional[EncoderOut] = None,
+        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
+        features_only: bool = False,
+        alignment_layer: Optional[int] = None,
+        alignment_heads: Optional[int] = None,
+        src_lengths: Optional[Any] = None,
+        return_all_hiddens: bool = False,
+    ):
+        self.layer_select.sample(self.lang_idx)
+        return super().forward(
+            prev_output_tokens=prev_output_tokens,
+            encoder_out=encoder_out,
+            incremental_state=incremental_state,
+            features_only=features_only,
+            alignment_layer=alignment_layer,
+            src_lengths=src_lengths,
+            return_all_hiddens=return_all_hiddens,
+        )
+
+
+class LatentTransformerDecoderLayer(TransformerDecoderLayer):
+    """Decoder layer with each (non-residual) block weighted by samples from a
+    Bernoulli or Gumbel-Sigmoid distribution.
+
+    Args:
+        args (argparse.Namespace): parsed command-line arguments from standard
+            TransformerDecoderLayer.
+        idx (int): layer index (used to retrieve samples).
+        layer_select (LayerSelect, optional): instance of LayerSelect module with logits
+            parameters and sampling method.
+        no_encoder_attn (bool, optional): whether to attend to encoder outputs
+            (default: False).
+
+    """
+
+    def __init__(
+        self,
+        args,
+        idx,
+        layer_select=None,
+        no_encoder_attn=False,
+        add_bias_kv=False,
+        add_zero_attn=False,
+    ):
+        super().__init__(args, no_encoder_attn, add_bias_kv, add_zero_attn)
+        self.idx = idx
+        self.layer_select = layer_select
+
+    def residual_connection(self, x, residual):
+        return residual + x * self.layer_select(self.idx)
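Both layer classes gate the residual branch with `x * self.layer_select(idx)`, where the gate is a Gumbel-Sigmoid sample of a per-layer logit (see `modules/latent_layers.py`). A self-contained NumPy sketch of that relaxation, assuming the standard difference-of-Gumbels construction; the straight-through gradient trick used in training is omitted here:

```python
import numpy as np

def gumbel_sigmoid(logits, tau=5.0, hard=True, seed=0):
    """Sample a relaxed Bernoulli gate per logit; hard=True thresholds at 0.5."""
    rng = np.random.default_rng(seed)
    # If E ~ Exp(1), then -log(E) ~ Gumbel(0, 1)
    g1 = -np.log(rng.exponential(size=np.shape(logits)))
    g2 = -np.log(rng.exponential(size=np.shape(logits)))
    # Difference of two Gumbels pushed through a tempered sigmoid
    y_soft = 1.0 / (1.0 + np.exp(-(np.asarray(logits) + g1 - g2) / tau))
    return (y_soft > 0.5).astype(np.float64) if hard else y_soft

gates = gumbel_sigmoid(np.zeros(6))
print(gates)  # random 0/1 gate per layer
```

With `hard=False` (the `--soft-select` path) the gates stay in (0, 1) and act as soft layer weights instead of a discrete keep/drop decision.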
fairseq-0.10.2/examples/latent_depth/latent_depth_src/modules/__init__.py ADDED
File without changes
fairseq-0.10.2/examples/latent_depth/latent_depth_src/modules/latent_layers.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+
10
+ class LayerSelect(nn.Module):
11
+ """Compute samples (from a Gumbel-Sigmoid distribution) which is used as
12
+ either (soft) weighting or (hard) selection of residual connection.
13
+ https://arxiv.org/abs/2009.13102
14
+ """
15
+
16
+ def __init__(self, num_layers, num_logits, args):
17
+ super(LayerSelect, self).__init__()
18
+ self.args = args
19
+ self.layer_logits = torch.nn.Parameter(
20
+ torch.Tensor(num_logits, num_layers),
21
+ requires_grad=True,
22
+ )
23
+ self.hard_select = not (hasattr(args, "soft_select") and args.soft_select)
24
+ self.tau = getattr(args, "sampling_tau", 5)
25
+ self.detach_grad = False
26
+ self.layer_samples = [None] * num_logits
27
+
28
+ @staticmethod
29
+ def add_args(parser):
30
+ parser.add_argument(
31
+ "--soft-select",
32
+ action="store_true",
33
+ help="use soft samples in training an inference",
34
+ )
35
+ parser.add_argument("--sampling-tau", type=float, help="sampling temperature")
36
+
37
+ def sample(self, logit_idx):
38
+ """To leverage the efficiency of distributed training, samples for all
39
+ layers are computed at once for each logit_idx. Logits are parameters
40
+ learned independently of each other.
41
+
42
+ Args:
43
+ logit_idx: The index of logit parameters used for sampling.
44
+ """
45
+ assert logit_idx is not None
46
+ self.samples = self._gumbel_sigmoid(
47
+ self.layer_logits[logit_idx, :].detach()
48
+ if self.detach_grad
49
+ else self.layer_logits[logit_idx, :],
50
+ dim=-1,
51
+ tau=self.tau,
52
+ hard=self.hard_select,
53
+ )
54
+ self.layer_samples[logit_idx] = self.samples
55
+
56
+ def forward(self, i):
57
+ sample = self.samples[i]
58
+ return sample
59
+
60
+ def _gumbel_sigmoid(
61
+ self, logits, tau=1, hard=False, eps=1e-10, dim=-1, threshold=0.5
62
+ ):
63
+ # ~Gumbel(0,1)
64
+ gumbels1 = (
65
+ -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
66
+ .exponential_()
67
+ .log()
68
+ )
69
+ gumbels2 = (
70
+ -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format)
71
+ .exponential_()
72
+ .log()
73
+ )
74
+ # Difference of two gumbels because we apply a sigmoid
75
+ gumbels1 = (logits + gumbels1 - gumbels2) / tau
76
+ y_soft = gumbels1.sigmoid()
77
+ if hard:
78
+ # Straight through.
79
+ y_hard = torch.zeros_like(
80
+ logits, memory_format=torch.legacy_contiguous_format
81
+ ).masked_fill(y_soft > threshold, 1.0)
82
+ ret = y_hard - y_soft.detach() + y_soft
83
+ else:
84
+ # Reparametrization trick.
85
+ ret = y_soft
86
+ return ret
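The `_gumbel_sigmoid` above can be sketched torch-free: the difference of two Gumbel(0,1) samples is Logistic(0,1) noise, so a hard sample equals 1 with probability exactly `sigmoid(logit)` (the threshold test is a sign test, so this holds for any `tau`). A minimal pure-Python illustration (function name is ours, not part of fairseq; the straight-through gradient trick from the torch code has no pure-Python analogue):

```python
import math
import random

def gumbel_sigmoid_sample(logit, tau=1.0, hard=False, threshold=0.5):
    # Gumbel(0,1) noise via -log(Exp(1)); the difference g1 - g2 is Logistic(0,1),
    # the 2-class analogue of the Gumbel-Softmax trick.
    g1 = -math.log(random.expovariate(1.0))
    g2 = -math.log(random.expovariate(1.0))
    y_soft = 1.0 / (1.0 + math.exp(-(logit + g1 - g2) / tau))
    if hard:
        # hard selection: binarize the relaxed sample
        return 1.0 if y_soft > threshold else 0.0
    return y_soft

random.seed(0)
n = 20000
freq = sum(gumbel_sigmoid_sample(1.0, hard=True) for _ in range(n)) / n
print(round(freq, 2))  # close to sigmoid(1.0) ~= 0.73
```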
fairseq-0.10.2/examples/m2m_100/README.md ADDED
@@ -0,0 +1,209 @@
1
+ # Beyond English-Centric Multilingual Machine Translation
2
+
3
+ ## Introduction
4
+ In this work, we create a true Many-to-Many multilingual translation model that can translate directly between any pair of 100 languages. Our focus on non-English-Centric models brings gains of more than 10 BLEU when directly translating between non-English directions while performing competitively with the best single systems of WMT.
5
+
6
+ If you are new to using fairseq, read the following walkthrough. Otherwise, skip to the sections below.
7
+
8
+ 0. **Generation Data**
9
+
10
+ To download the generation data, follow the commands below. Note that all datasets need to be detokenized *before* applying SPM in the data preprocessing step. If you use these evaluation datasets, please cite their associated papers.
11
+ ```bash
12
+ # WMT - use sacrebleu, example here:
13
+ sacrebleu -t wmt14 -l fr-en --echo src > wmt.test.fr-en.fr
14
+ sacrebleu -t wmt14 -l fr-en --echo ref > wmt.test.fr-en.en
15
+
16
+ # WAT
17
+ wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2019.my-en.zip
18
+ unzip wat2019.my-en.zip
19
+
20
+ # FLORES
21
+ # download from: https://github.com/facebookresearch/flores
22
+
23
+ # TED - need to detokenize with Moses!
24
+ # from: https://github.com/neulab/word-embeddings-for-nmt
25
+ wget http://phontron.com/data/ted_talks.tar.gz
26
+
27
+ # Autshumato
28
+ # request to download: https://repo.sadilar.org/handle/20.500.12185/397
29
+
30
+ # Tatoeba Challenge
31
+ # available here: https://github.com/Helsinki-NLP/Tatoeba-Challenge
32
+ ```
33
+
34
+ 1. **Training Data**
35
+
36
+ To produce the training data, we use a combination of [CCMatrix](https://arxiv.org/abs/1911.04944) and [CCAligned](https://arxiv.org/abs/1911.06154). Check out the instructions [here](https://github.com/facebookresearch/LASER/tree/master/tasks/CCMatrix) to download the raw data.
37
+
38
+ 2. **Preprocess Data**
39
+
40
+ After downloading the raw data, you will need to postprocess the data, then apply SPM, then binarize. Note that it is very important that you run the postprocessing script, because it removes any instances of the evaluation data from the mined training data.
41
+
42
+ ```bash
43
+ # preprocess data
44
+
45
+ # remove sentences with more than 50% punctuation
46
+ python /path/to/fairseq/examples/m2m_100/process_data/remove_too_much_punc.py
47
+
48
+ # deduplicate training data
49
+ paste /path/to/datadir/train.$src /path/to/datadir/train.$tgt | awk '!x[$0]++' > /path/to/datadir/train.dedup
50
+ echo "keeping $(wc -l /path/to/datadir/train.dedup) bitext out of $(wc -l /path/to/datadir/train.$src)"
51
+ cut -f1 /path/to/datadir/train.dedup > /path/to/datadir/train.$src
52
+ cut -f2 /path/to/datadir/train.dedup > /path/to/datadir/train.$tgt
53
+
54
+ # remove all instances of evaluation data from the training data
55
+ python /path/to/fairseq/examples/m2m_100/process_data/dedup_data.py
56
+
57
+ # frequency cleaning
58
+ wget https://dl.fbaipublicfiles.com/m2m_100/histograms.tar.gz
59
+ tar -xvzf histograms.tar.gz
60
+ python /path/to/fairseq/examples/m2m_100/process_data/clean_histogram.py --src $src --tgt $tgt --src-file /path/to/source/file --tgt-file /path/to/output/file --src-output-file source_output.$src --tgt-output-file target_output.$tgt --histograms /path/to/histograms
61
+
62
+ # apply SPM
63
+ wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
64
+ python /path/to/fairseq/scripts/spm_encode.py \
65
+ --model spm.128k.model \
66
+ --output_format=piece \
67
+ --inputs=/path/to/input/file/here \
68
+ --outputs=/path/to/output/file/here
69
+
70
+ # length ratio cleaning
71
+ perl mosesdecoder/scripts/training/clean-corpus-n.perl --ratio 3 /path/to/training/data/train.spm.$src-$tgt $src $tgt /path/to/output/directory/train.spm.$src-$tgt 1 250
72
+
73
+ # binarize data
74
+ wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
75
+ fairseq-preprocess \
76
+ --source-lang $src --target-lang $tgt \
77
+ --testpref spm.$src.$tgt \
78
+ --thresholdsrc 0 --thresholdtgt 0 \
79
+ --destdir data_bin \
80
+ --srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
81
+ ```
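The `paste … | awk '!x[$0]++'` step above keeps only the first occurrence of each source/target pair. A small Python sketch of the same deduplication (function name is ours, not part of fairseq):

```python
def dedup_bitext(src_lines, tgt_lines):
    # mirror `paste src tgt | awk '!x[$0]++'`: keep the first occurrence of each pair
    seen = set()
    kept_src, kept_tgt = [], []
    for s, t in zip(src_lines, tgt_lines):
        if (s, t) not in seen:
            seen.add((s, t))
            kept_src.append(s)
            kept_tgt.append(t)
    return kept_src, kept_tgt

print(dedup_bitext(["a", "b", "a"], ["x", "y", "x"]))  # (['a', 'b'], ['x', 'y'])
```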
82
+
83
+ 3. **Training Scripts**
84
+
85
+ To reproduce the training of our models, we train with fairseq-py's multilingual translation [task](https://github.com/pytorch/fairseq/tree/master/examples/multilingual). If you are interested in model parallel training, also check out [fairscale](https://github.com/facebookresearch/fairscale).
86
+
87
+ 4. **Generation**
88
+
89
+ To generate from our models, follow the commands in the generation section below.
90
+
91
+
92
+ If you use any of the resources listed here, please cite:
93
+ ```bibtex
94
+ @article{fan2020beyond,
95
+ title={Beyond English-Centric Multilingual Machine Translation},
96
+ author={Fan, Angela and Bhosale, Shruti and Schwenk, Holger and Ma, Zhiyi and El-Kishky, Ahmed and Goyal, Siddharth and Baines, Mandeep and Celebi, Onur and Wenzek, Guillaume and Chaudhary, Vishrav and Goyal, Naman and Birch, Tom and Liptchinsky, Vitaliy and Edunov, Sergey and Grave, Edouard and Auli, Michael and Joulin, Armand},
97
+ journal={arXiv preprint},
98
+ year={2020}
99
+ }
100
+
101
+ @article{schwenk2019ccmatrix,
102
+ title={Ccmatrix: Mining billions of high-quality parallel sentences on the web},
103
+ author={Schwenk, Holger and Wenzek, Guillaume and Edunov, Sergey and Grave, Edouard and Joulin, Armand},
104
+ journal={arXiv preprint arXiv:1911.04944},
105
+ year={2019}
106
+ }
107
+
108
+ @article{el2019massive,
109
+ title={A Massive Collection of Cross-Lingual Web-Document Pairs},
110
+ author={El-Kishky, Ahmed and Chaudhary, Vishrav and Guzman, Francisco and Koehn, Philipp},
111
+ journal={arXiv preprint arXiv:1911.06154},
112
+ year={2019}
113
+ }
114
+ ```
115
+
116
+
117
+ ## Trained Models
118
+
119
+ Looking for other trained models? Check back soon.
120
+
121
+ Model | Description | Download
122
+ ---|---|---
123
+ `12b_last_checkpoint` | 12B parameter model trained on many-to-many training data for 100 languages | [12b_last_checkpoint](https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt)
124
+
125
+
126
+ ## SentencePiece Model
127
+
128
+ ```bash
129
+ wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
130
+ ```
131
+
132
+ ## Generation with M2M-100
133
+
134
+ ### Encode using our SentencePiece Model
135
+
136
+ Note: Install SentencePiece from [here](https://github.com/google/sentencepiece)
137
+
138
+ ```bash
139
+ fairseq=/path/to/fairseq
140
+ cd $fairseq
141
+ sacrebleu --echo src -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.de
142
+ sacrebleu --echo ref -l de-fr -t wmt19 | head -n 20 > raw_input.de-fr.fr
143
+ wget https://dl.fbaipublicfiles.com/m2m_100/spm.128k.model
144
+ for lang in de fr ; do
145
+ python scripts/spm_encode.py \
146
+ --model spm.128k.model \
147
+ --output_format=piece \
148
+ --inputs=raw_input.de-fr.${lang} \
149
+ --outputs=spm.de-fr.${lang}
150
+ done
151
+ ```
152
+
153
+ ### Binarization
154
+
155
+ ```bash
156
+ wget https://dl.fbaipublicfiles.com/m2m_100/data_dict.128k.txt
157
+ fairseq-preprocess \
158
+ --source-lang de --target-lang fr \
159
+ --testpref spm.de-fr \
160
+ --thresholdsrc 0 --thresholdtgt 0 \
161
+ --destdir data_bin \
162
+ --srcdict data_dict.128k.txt --tgtdict data_dict.128k.txt
163
+ ```
164
+
165
+ ### Generation on a V100 GPU
166
+
167
+ ```bash
168
+ wget https://dl.fbaipublicfiles.com/m2m_100/model_dict.128k.txt
169
+ wget https://dl.fbaipublicfiles.com/m2m_100/language_pairs.txt
170
+ wget https://dl.fbaipublicfiles.com/m2m_100/12b_last_checkpoint.pt
171
+ fairseq-generate \
172
+ data_bin \
173
+ --batch-size 1 \
174
+ --path 12b_last_checkpoint.pt \
175
+ --fixed-dictionary model_dict.128k.txt \
176
+ -s de -t fr \
177
+ --remove-bpe 'sentencepiece' \
178
+ --beam 5 \
179
+ --task translation_multi_simple_epoch \
180
+ --lang-pairs language_pairs.txt \
181
+ --decoder-langtok --encoder-langtok src \
182
+ --gen-subset test \
183
+ --fp16 \
184
+ --dataset-impl mmap \
185
+ --distributed-world-size 1 --distributed-no-spawn \
186
+ --pipeline-model-parallel \
187
+ --pipeline-chunks 1 \
188
+ --pipeline-encoder-balance '[26]' \
189
+ --pipeline-encoder-devices '[0]' \
190
+ --pipeline-decoder-balance '[1,24,1]' \
191
+ --pipeline-decoder-devices '[0,1,0]' > gen_out
192
+ ```
193
+ ## Evaluation with M2M-100
194
+
195
+ ### Tokenization
196
+
197
+ Note: Refer to tokenizers/README.md for more details on tokenization.
198
+
199
+ ```bash
200
+ cd ${fairseq}/examples/m2m_100
201
+ cat ${fairseq}/gen_out | grep -P "^H" | sort -V | cut -f 3- | sh tok.sh fr > hyp
202
+ cat ${fairseq}/raw_input.de-fr.fr | sh tok.sh fr > ref
203
+ ```
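The `grep -P "^H" | sort -V | cut -f 3-` pipeline above extracts hypotheses from `fairseq-generate` output, which writes each hypothesis as `H-<id><tab><score><tab><text>`. A Python equivalent (function name is ours, not part of fairseq):

```python
def extract_hypotheses(gen_out_lines):
    # keep H-lines, sort numerically by sample id, drop the id and score columns
    hyps = []
    for line in gen_out_lines:
        if line.startswith("H-"):
            tag, _score, text = line.rstrip("\n").split("\t", 2)
            hyps.append((int(tag[2:]), text))
    return [text for _, text in sorted(hyps)]

lines = ["S-0\tsrc", "H-1\t-0.51\tworld", "T-1\tref", "H-0\t-0.20\thello"]
print(extract_hypotheses(lines))  # ['hello', 'world']
```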
204
+
205
+ ### BLEU
206
+
207
+ ```bash
208
+ sacrebleu -tok 'none' ref < hyp
209
+ ```
fairseq-0.10.2/examples/m2m_100/install_dependecies.sh ADDED
@@ -0,0 +1,78 @@
1
+ #!/usr/bin/env bash
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ CWD=`pwd`
9
+ INSTALL_PATH=$CWD/tokenizers/thirdparty
10
+
11
+ MOSES=$INSTALL_PATH/mosesdecoder
12
+ if [ ! -d $MOSES ]; then
13
+ echo 'Cloning Moses github repository (for tokenization scripts)...'
14
+ git clone https://github.com/moses-smt/mosesdecoder.git $MOSES
15
+ cd $MOSES
16
+ # To deal with differences in handling ' vs "
17
+ git checkout 03578921cc1a03402
18
+ cd -
19
+ fi
20
+
21
+ WMT16_SCRIPTS=$INSTALL_PATH/wmt16-scripts
22
+ if [ ! -d $WMT16_SCRIPTS ]; then
23
+ echo 'Cloning Romanian tokenization scripts'
24
+ git clone https://github.com/rsennrich/wmt16-scripts.git $WMT16_SCRIPTS
25
+ fi
26
+
27
+ KYTEA=$INSTALL_PATH/kytea
28
+ if [ ! -f $KYTEA/bin/kytea ]; then
29
+ git clone https://github.com/neubig/kytea.git $KYTEA
30
+ cd $KYTEA
31
+ autoreconf -i
32
+ ./configure --prefix=`pwd`
33
+ make
34
+ make install
35
+ cd ..
36
+ fi
37
+
38
+ export MECAB=$INSTALL_PATH/mecab-0.996-ko-0.9.2
39
+ if [ ! -f $MECAB/bin/mecab ]; then
40
+ cd $INSTALL_PATH
41
+ curl -LO https://bitbucket.org/eunjeon/mecab-ko/downloads/mecab-0.996-ko-0.9.2.tar.gz
42
+ tar zxfv mecab-0.996-ko-0.9.2.tar.gz
43
+ cd mecab-0.996-ko-0.9.2/
44
+ ./configure --prefix=`pwd`
45
+ make
46
+ make install
47
+
48
+ cd ..
49
+ curl -LO https://bitbucket.org/eunjeon/mecab-ko-dic/downloads/mecab-ko-dic-2.1.1-20180720.tar.gz
50
+ tar zxfv mecab-ko-dic-2.1.1-20180720.tar.gz
51
+ cd mecab-ko-dic-2.1.1-20180720/
52
+ ./autogen.sh
53
+ ./configure --prefix=`pwd` --with-dicdir=$MECAB/lib/mecab/dic/mecab-ko-dic --with-mecab-config=$MECAB/bin/mecab-config
54
+ make
55
+ sh -c 'echo "dicdir=$MECAB/lib/mecab/dic/mecab-ko-dic" > $MECAB/etc/mecabrc'
56
+ make install
57
+ cd $CWD
58
+ fi
59
+
60
+ INDIC_RESOURCES_PATH=$INSTALL_PATH/indic_nlp_resources
61
+ if [ ! -d $INDIC_RESOURCES_PATH ]; then
62
+ echo 'Cloning indic_nlp_resources'
63
+ git clone https://github.com/anoopkunchukuttan/indic_nlp_resources.git $INDIC_RESOURCES_PATH
64
+ fi
65
+
66
+
67
+ if [ ! -f $INSTALL_PATH/seg_my.py ]; then
68
+ cd $INSTALL_PATH
69
+ wget http://lotus.kuee.kyoto-u.ac.jp/WAT/my-en-data/wat2020.my-en.zip
70
+ unzip wat2020.my-en.zip
71
+ # switch to python3
72
+ cat wat2020.my-en/myseg.py |sed 's/^sys.std/###sys.std/g' | sed 's/### sys/sys/g' | sed 's/unichr/chr/g' > seg_my.py
73
+ cd $CWD
74
+ fi
75
+
76
+
77
+ pip install pythainlp sacrebleu indic-nlp-library
78
+
fairseq-0.10.2/examples/m2m_100/process_data/remove_too_much_punc.py ADDED
@@ -0,0 +1,36 @@
1
+ import gzip
2
+ import argparse
3
+ from string import punctuation
4
+
5
+ def len_no_punc(s, punc):
6
+ # counts the punctuation characters in s
+ return len([ch for ch in s if ch in punc])
7
+
8
+ def filter_overpunc(len_npunc, len_sen):
9
+ return len_npunc < 0.5*len_sen
10
+
11
+ def main(args):
12
+ punc = punctuation + "—|–"
13
+ print('Processing file {}'.format(args.input))
14
+ with gzip.open(args.input, 'rt', encoding=args.encoding) as tsv:
15
+ with open(args.bitext + '.' + args.src_lang, 'wt', encoding=args.encoding) as fsrc:
16
+ with open(args.bitext + '.' + args.tgt_lang, 'wt', encoding=args.encoding) as ftgt:
17
+ for line in tsv:
18
+ fields = line.split('\t')
19
+
20
+ src, tgt = fields[1], fields[2]
21
+
22
+ nchar_npunc_src = len_no_punc(src, punc)
23
+ nchar_npunc_tgt = len_no_punc(tgt, punc)
24
+
25
+ if filter_overpunc(nchar_npunc_src, len(src)) and filter_overpunc(nchar_npunc_tgt, len(tgt)):
26
+ fsrc.write(src.strip() + '\n')
27
+ ftgt.write(tgt.strip() + '\n')
28
+
29
+ if __name__ == '__main__':
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument("--input", required=True, type=str)
32
+ parser.add_argument('--encoding', default='utf-8', help='character encoding for input/output')
33
+ parser.add_argument('--bitext', type=str, required=True, help='language direction')
34
+ parser.add_argument('--src-lang', type=str, required=True, help='Source language')
35
+ parser.add_argument('--tgt-lang', type=str, required=True, help='Target language')
36
+ main(parser.parse_args())
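The 50% punctuation threshold implemented above can be checked in isolation. A minimal standalone version of the filter predicate (names are ours, not part of fairseq):

```python
from string import punctuation

PUNC = punctuation + "—|–"  # same extended punctuation set as the script above

def too_much_punc(s, ratio=0.5):
    # True if punctuation makes up at least `ratio` of the characters,
    # i.e. the sentence would be filtered out
    n_punc = sum(ch in PUNC for ch in s)
    return n_punc >= ratio * len(s)

print(too_much_punc("Hello, world!"))    # False -- kept
print(too_much_punc("!!! ??? *** !!!"))  # True  -- filtered
```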
fairseq-0.10.2/examples/m2m_100/tok.sh ADDED
@@ -0,0 +1,83 @@
1
+ #!/usr/bin/env bash
2
+ # Copyright (c) 2019-present, Facebook, Inc.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+ #
8
+
9
+ set -e
10
+
11
+ TOKENIZERS_SCRIPTS=tokenizers
12
+ INSTALL_PATH=$TOKENIZERS_SCRIPTS/thirdparty
13
+
14
+ N_THREADS=8
15
+
16
+ lg=$1
17
+
18
+ MOSES=$INSTALL_PATH/mosesdecoder
19
+ REPLACE_UNICODE_PUNCT=$MOSES/scripts/tokenizer/replace-unicode-punctuation.perl
20
+ NORM_PUNC=$MOSES/scripts/tokenizer/normalize-punctuation.perl
21
+ REM_NON_PRINT_CHAR=$MOSES/scripts/tokenizer/remove-non-printing-char.perl
22
+ TOKENIZER=$MOSES/scripts/tokenizer/tokenizer.perl
23
+
24
+ # special tokenization for Romanian
25
+ WMT16_SCRIPTS=$INSTALL_PATH/wmt16-scripts
26
+
27
+ NORMALIZE_ROMANIAN=$WMT16_SCRIPTS/preprocess/normalise-romanian.py
28
+ REMOVE_DIACRITICS=$WMT16_SCRIPTS/preprocess/remove-diacritics.py
29
+
30
+ # Burmese
31
+ MY_SEGMENT=$INSTALL_PATH/seg_my.py
32
+
33
+ # Arabic
34
+ AR_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenizer_ar.sh
35
+
36
+ # Korean
37
+ KO_SEGMENT=$TOKENIZERS_SCRIPTS/seg_ko.sh
38
+
39
+ # Japanese
40
+ JA_SEGMENT=$TOKENIZERS_SCRIPTS/seg_ja.sh
41
+
42
+ # Indic
43
+ IN_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_indic.py
44
+ INDIC_RESOURCES_PATH=$INSTALL_PATH/indic_nlp_resources
45
+
46
+ # Thai
47
+ THAI_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_thai.py
48
+
49
+ # Chinese
50
+ CHINESE_TOKENIZER=$TOKENIZERS_SCRIPTS/tokenize_zh.py
51
+
52
+ # Chinese
53
+ if [ "$lg" = "zh" ]; then
54
+ cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | python $CHINESE_TOKENIZER
55
+ # Thai
56
+ elif [ "$lg" = "th" ]; then
57
+ cat - | python $THAI_TOKENIZER
58
+ # Japanese
59
+ elif [ "$lg" = "ja" ]; then
60
+ cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | ${JA_SEGMENT}
61
+ # Korean
62
+ elif [ "$lg" = "ko" ]; then
63
+ cat - | $REM_NON_PRINT_CHAR | ${KO_SEGMENT}
64
+ # Romanian
65
+ elif [ "$lg" = "ro" ]; then
66
+ cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $NORMALIZE_ROMANIAN | $REMOVE_DIACRITICS | $TOKENIZER -no-escape -threads $N_THREADS -l $lg
67
+ # Burmese
68
+ elif [ "$lg" = "my" ]; then
69
+ cat - | python ${MY_SEGMENT}
70
+ # Arabic
71
+ elif [ "$lg" = "ar" ]; then
72
+ cat - | ${AR_TOKENIZER}
73
+ # Indic
74
+ elif [ "$lg" = "ne" ]; then
75
+ cat - | python ${IN_TOKENIZER} $lg
76
+ elif [ "$lg" = "si" ]; then
77
+ cat - | python ${IN_TOKENIZER} $lg
78
+ elif [ "$lg" = "hi" ]; then
79
+ cat - | python ${IN_TOKENIZER} $lg
80
+ # other languages
81
+ else
82
+ cat - | $REPLACE_UNICODE_PUNCT | $NORM_PUNC -l $lg | $REM_NON_PRINT_CHAR | $TOKENIZER -no-escape -threads $N_THREADS -l $lg
83
+ fi
fairseq-0.10.2/examples/mbart/README.md ADDED
@@ -0,0 +1,123 @@
1
+ # MBART: Multilingual Denoising Pre-training for Neural Machine Translation
2
+ [https://arxiv.org/abs/2001.08210]
3
+
4
+ ## Introduction
5
+
6
+ MBART is a sequence-to-sequence denoising auto-encoder pre-trained on large-scale monolingual corpora in many languages using the BART objective. mBART is one of the first methods for pre-training a complete sequence-to-sequence model by denoising full texts in multiple languages, while previous approaches have focused only on the encoder, decoder, or reconstructing parts of the text.
7
+
8
+ ## Pre-trained models
9
+
10
+ Model | Description | # params | Download
11
+ ---|---|---|---
12
+ `mbart.CC25` | mBART model with 12 encoder and decoder layers trained on 25 languages' monolingual corpus | 610M | [mbart.CC25.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz)
13
+ `mbart.ft.ro_en` | finetune mBART cc25 model on ro-en language pairs | 610M | [mbart.cc25.ft.enro.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz)
14
+
15
+ ## Results
16
+
17
+ **[WMT16 EN-RO](https://www.statmt.org/wmt16/translation-task.html)**
18
+
19
+ _(test set, no additional data used)_
20
+
21
+ Model | en-ro | ro-en
22
+ ---|---|---
23
+ `Random` | 34.3 | 34.0
24
+ `mbart.cc25` | 37.7 | 37.8
25
+ `mbart.enro.bilingual` | 38.5 | 38.5
26
+
27
+ ## BPE data
28
+ Download the model and install SPM from [sentencepiece](https://github.com/google/sentencepiece):
29
+ ```bash
30
+ # download model
31
+ wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.CC25.tar.gz
32
+ tar -xzvf mbart.CC25.tar.gz
+ ```
+
+ Apply BPE to the data:
33
+ ```bash
34
+ SPM=/path/to/sentencepiece/build/src/spm_encode
35
+ MODEL=sentence.bpe.model
36
+ ${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${SRC} > ${DATA}/${TRAIN}.spm.${SRC} &
37
+ ${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${TGT} > ${DATA}/${TRAIN}.spm.${TGT} &
38
+ ${SPM} --model=${MODEL} < ${DATA}/${VALID}.${SRC} > ${DATA}/${VALID}.spm.${SRC} &
39
+ ${SPM} --model=${MODEL} < ${DATA}/${VALID}.${TGT} > ${DATA}/${VALID}.spm.${TGT} &
40
+ ${SPM} --model=${MODEL} < ${DATA}/${TEST}.${SRC} > ${DATA}/${TEST}.spm.${SRC} &
41
+ ${SPM} --model=${MODEL} < ${DATA}/${TEST}.${TGT} > ${DATA}/${TEST}.spm.${TGT} &
42
+ wait
+ ```
43
+
44
+ ## Preprocess data
45
+
46
+ ```bash
47
+ DICT=dict.txt
48
+ fairseq-preprocess \
49
+ --source-lang ${SRC} \
50
+ --target-lang ${TGT} \
51
+ --trainpref ${DATA}/${TRAIN}.spm \
52
+ --validpref ${DATA}/${VALID}.spm \
53
+ --testpref ${DATA}/${TEST}.spm \
54
+ --destdir ${DEST}/${NAME} \
55
+ --thresholdtgt 0 \
56
+ --thresholdsrc 0 \
57
+ --srcdict ${DICT} \
58
+ --tgtdict ${DICT} \
59
+ --workers 70
60
+ ```
61
+
62
+ ## Finetune on EN-RO
63
+ Fine-tune the pretrained mBART CC25 model:
64
+
65
+ ```bash
66
+ PRETRAIN=mbart.cc25 # fix if you moved the downloaded checkpoint
67
+ langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN
68
+
69
+ fairseq-train path_2_data \
70
+ --encoder-normalize-before --decoder-normalize-before \
71
+ --arch mbart_large --layernorm-embedding \
72
+ --task translation_from_pretrained_bart \
73
+ --source-lang en_XX --target-lang ro_RO \
74
+ --criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
75
+ --optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
76
+ --lr-scheduler polynomial_decay --lr 3e-05 --min-lr -1 --warmup-updates 2500 --total-num-update 40000 \
77
+ --dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
78
+ --max-tokens 1024 --update-freq 2 \
79
+ --save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
80
+ --seed 222 --log-format simple --log-interval 2 \
81
+ --restore-file $PRETRAIN \
82
+ --reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
83
+ --langs $langs \
84
+ --ddp-backend no_c10d
85
+ ```
86
+ ## Generate on EN-RO
87
+ Compute sacreBLEU for the fine-tuned EN-RO model.
88
+
89
+ Get the tokenizer [here](https://github.com/rsennrich/wmt16-scripts):
90
+ ```bash
91
+ wget https://dl.fbaipublicfiles.com/fairseq/models/mbart/mbart.cc25.ft.enro.tar.gz
92
+ tar -xzvf mbart.cc25.ft.enro.tar.gz
93
+ ```
94
+
95
+ ```bash
96
+ model_dir=MBART_finetuned_enro # fix if you moved the checkpoint
97
+
98
+ fairseq-generate path_2_data \
99
+ --path $model_dir/model.pt \
100
+ --task translation_from_pretrained_bart \
101
+ --gen-subset test \
102
+ -t ro_RO -s en_XX \
103
+ --bpe 'sentencepiece' --sentencepiece-model $model_dir/sentence.bpe.model \
104
+ --sacrebleu --remove-bpe 'sentencepiece' \
105
+ --batch-size 32 --langs $langs > en_ro
106
+
107
+ cat en_ro | grep -P "^H" |sort -V |cut -f 3- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.hyp
108
+ cat en_ro | grep -P "^T" |sort -V |cut -f 2- | sed 's/\[ro_RO\]//g' |$TOKENIZER ro > en_ro.ref
109
+ sacrebleu -tok 'none' -s 'none' en_ro.ref < en_ro.hyp
110
+ ```
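The `sed 's/\[ro_RO\]//g'` step above removes the decoder language token from the hypothesis and reference lines; an equivalent helper in Python (name is ours, not part of fairseq):

```python
def strip_lang_tag(line, tag="[ro_RO]"):
    # remove the mBART language token and surrounding whitespace
    return line.replace(tag, "").strip()

print(strip_lang_tag("[ro_RO] Buna ziua"))  # 'Buna ziua'
```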
111
+
112
+ ## Citation
113
+
114
+ ```bibtex
115
+ @article{liu2020multilingual,
116
+ title={Multilingual Denoising Pre-training for Neural Machine Translation},
117
+ author={Yinhan Liu and Jiatao Gu and Naman Goyal and Xian Li and Sergey Edunov and Marjan Ghazvininejad and Mike Lewis and Luke Zettlemoyer},
118
+ year={2020},
119
+ eprint={2001.08210},
120
+ archivePrefix={arXiv},
121
+ primaryClass={cs.CL}
122
+ }
123
+ ```
fairseq-0.10.2/examples/roberta/README.glue.md ADDED
@@ -0,0 +1,99 @@
1
+ # Finetuning RoBERTa on GLUE tasks
2
+
3
+ ### 1) Download the data from the GLUE website (https://gluebenchmark.com/tasks) using the following commands:
4
+ ```bash
5
+ wget https://gist.githubusercontent.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e/raw/17b8dd0d724281ed7c3b2aeeda662b92809aadd5/download_glue_data.py
6
+ python download_glue_data.py --data_dir glue_data --tasks all
7
+ ```
8
+
9
+ ### 2) Preprocess GLUE task data:
10
+ ```bash
11
+ ./examples/roberta/preprocess_GLUE_tasks.sh glue_data <glue_task_name>
12
+ ```
13
+ `glue_task_name` is one of the following:
14
+ `{ALL, QQP, MNLI, QNLI, MRPC, RTE, STS-B, SST-2, CoLA}`
15
+ Use `ALL` for preprocessing all the GLUE tasks.
16
+
17
+ ### 3) Fine-tuning on GLUE task:
18
+ Example fine-tuning command for the `RTE` task:
19
+ ```bash
20
+ TOTAL_NUM_UPDATES=2036 # 10 epochs through RTE for bsz 16
21
+ WARMUP_UPDATES=122 # 6 percent of the number of updates
22
+ LR=2e-05 # Peak LR for polynomial LR scheduler.
23
+ NUM_CLASSES=2
24
+ MAX_SENTENCES=16 # Batch size.
25
+ ROBERTA_PATH=/path/to/roberta/model.pt
26
+
27
+ CUDA_VISIBLE_DEVICES=0 fairseq-train RTE-bin/ \
28
+ --restore-file $ROBERTA_PATH \
29
+ --max-positions 512 \
30
+ --batch-size $MAX_SENTENCES \
31
+ --max-tokens 4400 \
32
+ --task sentence_prediction \
33
+ --reset-optimizer --reset-dataloader --reset-meters \
34
+ --required-batch-size-multiple 1 \
35
+ --init-token 0 --separator-token 2 \
36
+ --arch roberta_large \
37
+ --criterion sentence_prediction \
38
+ --num-classes $NUM_CLASSES \
39
+ --dropout 0.1 --attention-dropout 0.1 \
40
+ --weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
41
+ --clip-norm 0.0 \
42
+ --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
43
+ --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
44
+ --max-epoch 10 \
45
+ --find-unused-parameters \
46
+ --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric;
47
+ ```
48
+
49
+ For each GLUE task, you will need to use the following command-line arguments:
50
+
51
+ Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
52
+ ---|---|---|---|---|---|---|---|---
53
+ `--num-classes` | 3 | 2 | 2 | 2 | 2 | 2 | 2 | 1
54
+ `--lr` | 1e-5 | 1e-5 | 1e-5 | 2e-5 | 1e-5 | 1e-5 | 1e-5 | 2e-5
55
+ `--batch-size` | 32 | 32 | 32 | 16 | 32 | 16 | 16 | 16
56
+ `--total-num-update` | 123873 | 33112 | 113272 | 2036 | 20935 | 2296 | 5336 | 3598
57
+ `--warmup-updates` | 7432 | 1986 | 28318 | 122 | 1256 | 137 | 320 | 214
58
+
59
+ For `STS-B` additionally add `--regression-target --best-checkpoint-metric loss` and remove `--maximize-best-checkpoint-metric`.
60
+
61
+ **Note:**
62
+
63
+ a) `--total-num-update` is used by the `polynomial_decay` scheduler and is calculated for `--max-epoch=10` and `--batch-size=16/32` depending on the task.
64
+
65
+ b) The above command-line arguments and hyperparameters were tested on one Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you can increase `--update-freq` and reduce `--batch-size`.
66
+
67
+ c) All the settings in the above table are suggested settings based on our hyperparameter search within a fixed search space (for careful comparison across models). You might be able to find better metrics with a wider hyperparameter search.
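Per note a), the update counts follow from dataset size, batch size, and epoch count, with warmup at roughly 6% of total updates (as in the `RTE` example above). A sketch of the arithmetic (helper name and example numbers are ours; the actual table values come from the authors' runs):

```python
import math

def glue_schedule(num_examples, batch_size, epochs, warmup_frac=0.06):
    # total updates for the polynomial_decay scheduler, plus ~6% warmup
    updates_per_epoch = math.ceil(num_examples / batch_size)
    total = updates_per_epoch * epochs
    return total, int(round(total * warmup_frac))

print(glue_schedule(1000, 16, 10))  # (630, 38)
```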
68
+
69
+ ### Inference on GLUE task
70
+ After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python code snippet:
71
+
72
+ ```python
73
+ from fairseq.models.roberta import RobertaModel
74
+
75
+ roberta = RobertaModel.from_pretrained(
76
+ 'checkpoints/',
77
+ checkpoint_file='checkpoint_best.pt',
78
+ data_name_or_path='RTE-bin'
79
+ )
80
+
81
+ label_fn = lambda label: roberta.task.label_dictionary.string(
82
+ [label + roberta.task.label_dictionary.nspecial]
83
+ )
84
+ ncorrect, nsamples = 0, 0
85
+ roberta.cuda()
86
+ roberta.eval()
87
+ with open('glue_data/RTE/dev.tsv') as fin:
88
+ fin.readline()
89
+ for index, line in enumerate(fin):
90
+ tokens = line.strip().split('\t')
91
+ sent1, sent2, target = tokens[1], tokens[2], tokens[3]
92
+ tokens = roberta.encode(sent1, sent2)
93
+ prediction = roberta.predict('sentence_classification_head', tokens).argmax().item()
94
+ prediction_label = label_fn(prediction)
95
+ ncorrect += int(prediction_label == target)
96
+ nsamples += 1
97
+ print('| Accuracy: ', float(ncorrect)/float(nsamples))
98
+
99
+ ```
fairseq-0.10.2/examples/roberta/README.md ADDED
@@ -0,0 +1,296 @@
1
+ # RoBERTa: A Robustly Optimized BERT Pretraining Approach
2
+
3
+ https://arxiv.org/abs/1907.11692
4
+
5
+ ## Introduction
6
+
7
+ RoBERTa iterates on BERT's pretraining procedure, including training the model longer, with bigger batches over more data; removing the next sentence prediction objective; training on longer sequences; and dynamically changing the masking pattern applied to the training data. See the associated paper for more details.

### What's New:

- January 2020: Italian model (UmBERTo) is available from Musixmatch Research: [UmBERTo](https://github.com/musixmatchresearch/umberto).
- November 2019: French model (CamemBERT) is available: [CamemBERT](https://github.com/pytorch/fairseq/tree/master/examples/camembert).
- November 2019: Multilingual encoder (XLM-RoBERTa) is available: [XLM-R](https://github.com/pytorch/fairseq/tree/master/examples/xlmr).
- September 2019: TensorFlow and TPU support via the [transformers library](https://github.com/huggingface/transformers).
- August 2019: RoBERTa is now supported in the [pytorch-transformers library](https://github.com/huggingface/pytorch-transformers).
- August 2019: Added [tutorial for finetuning on WinoGrande](https://github.com/pytorch/fairseq/tree/master/examples/roberta/wsc#roberta-training-on-winogrande-dataset).
- August 2019: Added [tutorial for pretraining RoBERTa using your own data](README.pretraining.md).

## Pre-trained models

Model | Description | # params | Download
---|---|---|---
`roberta.base` | RoBERTa using the BERT-base architecture | 125M | [roberta.base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz)
`roberta.large` | RoBERTa using the BERT-large architecture | 355M | [roberta.large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz)
`roberta.large.mnli` | `roberta.large` finetuned on [MNLI](http://www.nyu.edu/projects/bowman/multinli) | 355M | [roberta.large.mnli.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.mnli.tar.gz)
`roberta.large.wsc` | `roberta.large` finetuned on [WSC](wsc/README.md) | 355M | [roberta.large.wsc.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.wsc.tar.gz)

## Results

**[GLUE (Wang et al., 2019)](https://gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | MNLI | QNLI | QQP | RTE | SST-2 | MRPC | CoLA | STS-B
---|---|---|---|---|---|---|---|---
`roberta.base` | 87.6 | 92.8 | 91.9 | 78.7 | 94.8 | 90.2 | 63.6 | 91.2
`roberta.large` | 90.2 | 94.7 | 92.2 | 86.6 | 96.4 | 90.9 | 68.0 | 92.4
`roberta.large.mnli` | 90.2 | - | - | - | - | - | - | -

**[SuperGLUE (Wang et al., 2019)](https://super.gluebenchmark.com/)**
_(dev set, single model, single-task finetuning)_

Model | BoolQ | CB | COPA | MultiRC | RTE | WiC | WSC
---|---|---|---|---|---|---|---
`roberta.large` | 86.9 | 98.2 | 94.0 | 85.7 | 89.5 | 75.6 | -
`roberta.large.wsc` | - | - | - | - | - | - | 91.3

**[SQuAD (Rajpurkar et al., 2018)](https://rajpurkar.github.io/SQuAD-explorer/)**
_(dev set, no additional data used)_

Model | SQuAD 1.1 EM/F1 | SQuAD 2.0 EM/F1
---|---|---
`roberta.large` | 88.9/94.6 | 86.5/89.4

**[RACE (Lai et al., 2017)](http://www.qizhexie.com/data/RACE_leaderboard.html)**
_(test set)_

Model | Accuracy | Middle | High
---|---|---|---
`roberta.large` | 83.2 | 86.5 | 81.3

**[HellaSwag (Zellers et al., 2019)](https://rowanzellers.com/hellaswag/)**
_(test set)_

Model | Overall | In-domain | Zero-shot | ActivityNet | WikiHow
---|---|---|---|---|---
`roberta.large` | 85.2 | 87.3 | 83.1 | 74.6 | 90.9

**[Commonsense QA (Talmor et al., 2019)](https://www.tau-nlp.org/commonsenseqa)**
_(test set)_

Model | Accuracy
---|---
`roberta.large` (single model) | 72.1
`roberta.large` (ensemble) | 72.5

**[Winogrande (Sakaguchi et al., 2019)](https://arxiv.org/abs/1907.10641)**
_(test set)_

Model | Accuracy
---|---
`roberta.large` | 78.1

**[XNLI (Conneau et al., 2018)](https://arxiv.org/abs/1809.05053)**
_(TRANSLATE-TEST)_

Model | en | fr | es | de | el | bg | ru | tr | ar | vi | th | zh | hi | sw | ur
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---
`roberta.large.mnli` | 91.3 | 82.91 | 84.27 | 81.24 | 81.74 | 83.13 | 78.28 | 76.79 | 76.64 | 74.17 | 74.05 | 77.5 | 70.9 | 66.65 | 66.81

## Example usage

##### Load RoBERTa from torch.hub (PyTorch >= 1.1):
```python
import torch
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Load RoBERTa (for PyTorch 1.0 or custom models):
```bash
# Download the roberta.large model
wget https://dl.fbaipublicfiles.com/fairseq/models/roberta.large.tar.gz
tar -xzvf roberta.large.tar.gz
```
```python
# Load the model in fairseq
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained('/path/to/roberta.large', checkpoint_file='model.pt')
roberta.eval()  # disable dropout (or leave in train mode to finetune)
```

##### Apply Byte-Pair Encoding (BPE) to input text:
```python
tokens = roberta.encode('Hello world!')
assert tokens.tolist() == [0, 31414, 232, 328, 2]
roberta.decode(tokens)  # 'Hello world!'
```

##### Extract features from RoBERTa:
```python
# Extract the last layer's features
last_layer_features = roberta.extract_features(tokens)
assert last_layer_features.size() == torch.Size([1, 5, 1024])

# Extract all layers' features (layer 0 is the embedding layer)
all_layers = roberta.extract_features(tokens, return_all_hiddens=True)
assert len(all_layers) == 25
assert torch.all(all_layers[-1] == last_layer_features)
```

##### Use RoBERTa for sentence-pair classification tasks:
```python
# Download RoBERTa already finetuned for MNLI
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()  # disable dropout for evaluation

# Encode a pair of sentences and make a prediction
tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.')
roberta.predict('mnli', tokens).argmax()  # 0: contradiction

# Encode another pair of sentences
tokens = roberta.encode('Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.')
roberta.predict('mnli', tokens).argmax()  # 2: entailment
```

##### Register a new (randomly initialized) classification head:
```python
roberta.register_classification_head('new_task', num_classes=3)
logprobs = roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], grad_fn=<LogSoftmaxBackward>)
```

##### Batched prediction:
```python
import torch
from fairseq.data.data_utils import collate_tokens

roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.mnli')
roberta.eval()

batch_of_pairs = [
    ['Roberta is a heavily optimized version of BERT.', 'Roberta is not very optimized.'],
    ['Roberta is a heavily optimized version of BERT.', 'Roberta is based on BERT.'],
    ['potatoes are awesome.', 'I like to run.'],
    ['Mars is very far from earth.', 'Mars is very close.'],
]

batch = collate_tokens(
    [roberta.encode(pair[0], pair[1]) for pair in batch_of_pairs], pad_idx=1
)

logprobs = roberta.predict('mnli', batch)
print(logprobs.argmax(dim=1))
# tensor([0, 2, 1, 0])
```

##### Using the GPU:
```python
roberta.cuda()
roberta.predict('new_task', tokens)  # tensor([[-1.1050, -1.0672, -1.1245]], device='cuda:0', grad_fn=<LogSoftmaxBackward>)
```

## Advanced usage

#### Filling masks:

RoBERTa can be used to fill `<mask>` tokens in the input. Some examples from the
[Natural Questions dataset](https://ai.google.com/research/NaturalQuestions/):
```python
roberta.fill_mask('The first Star wars movie came out in <mask>', topk=3)
# [('The first Star wars movie came out in 1977', 0.9504708051681519, ' 1977'), ('The first Star wars movie came out in 1978', 0.009986862540245056, ' 1978'), ('The first Star wars movie came out in 1979', 0.009574787691235542, ' 1979')]

roberta.fill_mask('Vikram samvat calender is official in <mask>', topk=3)
# [('Vikram samvat calender is official in India', 0.21878819167613983, ' India'), ('Vikram samvat calender is official in Delhi', 0.08547237515449524, ' Delhi'), ('Vikram samvat calender is official in Gujarat', 0.07556215673685074, ' Gujarat')]

roberta.fill_mask('<mask> is the common currency of the European Union', topk=3)
# [('Euro is the common currency of the European Union', 0.9456493854522705, 'Euro'), ('euro is the common currency of the European Union', 0.025748178362846375, 'euro'), ('€ is the common currency of the European Union', 0.011183084920048714, '€')]
```

#### Pronoun disambiguation (Winograd Schema Challenge):

RoBERTa can be used to disambiguate pronouns. First install spaCy and download the English-language model:
```bash
pip install spacy
python -m spacy download en_core_web_lg
```

Next load the `roberta.large.wsc` model and call the `disambiguate_pronoun`
function. The pronoun should be surrounded by square brackets (`[]`) and the
query referent surrounded by underscores (`_`), or left blank to return the
predicted candidate text directly:
```python
roberta = torch.hub.load('pytorch/fairseq', 'roberta.large.wsc', user_dir='examples/roberta/wsc')
roberta.cuda()  # use the GPU (optional)

roberta.disambiguate_pronoun('The _trophy_ would not fit in the brown suitcase because [it] was too big.')
# True
roberta.disambiguate_pronoun('The trophy would not fit in the brown _suitcase_ because [it] was too big.')
# False

roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] feared violence.')
# 'The city councilmen'
roberta.disambiguate_pronoun('The city councilmen refused the demonstrators a permit because [they] advocated violence.')
# 'demonstrators'
```

See the [RoBERTa Winograd Schema Challenge (WSC) README](wsc/README.md) for more details on how to train this model.

#### Extract features aligned to words:

By default RoBERTa outputs one feature vector per BPE token. You can instead
realign the features to match [spaCy's word-level tokenization](https://spacy.io/usage/linguistic-features#tokenization)
with the `extract_features_aligned_to_words` method. This will compute a
weighted average of the BPE-level features for each word and expose them in
spaCy's `Token.vector` attribute:
```python
doc = roberta.extract_features_aligned_to_words('I said, "hello RoBERTa."')
assert len(doc) == 10
for tok in doc:
    print('{:10}{} (...)'.format(str(tok), tok.vector[:5]))
# <s>       tensor([-0.1316, -0.0386, -0.0832, -0.0477, 0.1943], grad_fn=<SliceBackward>) (...)
# I         tensor([ 0.0559, 0.1541, -0.4832, 0.0880, 0.0120], grad_fn=<SliceBackward>) (...)
# said      tensor([-0.1565, -0.0069, -0.8915, 0.0501, -0.0647], grad_fn=<SliceBackward>) (...)
# ,         tensor([-0.1318, -0.0387, -0.0834, -0.0477, 0.1944], grad_fn=<SliceBackward>) (...)
# "         tensor([-0.0486, 0.1818, -0.3946, -0.0553, 0.0981], grad_fn=<SliceBackward>) (...)
# hello     tensor([ 0.0079, 0.1799, -0.6204, -0.0777, -0.0923], grad_fn=<SliceBackward>) (...)
# RoBERTa   tensor([-0.2339, -0.1184, -0.7343, -0.0492, 0.5829], grad_fn=<SliceBackward>) (...)
# .         tensor([-0.1341, -0.1203, -0.1012, -0.0621, 0.1892], grad_fn=<SliceBackward>) (...)
# "         tensor([-0.1341, -0.1203, -0.1012, -0.0621, 0.1892], grad_fn=<SliceBackward>) (...)
# </s>      tensor([-0.0930, -0.0392, -0.0821, 0.0158, 0.0649], grad_fn=<SliceBackward>) (...)
```

#### Evaluating the `roberta.large.mnli` model:

Example Python code snippet to evaluate accuracy on the MNLI `dev_matched` set.
```python
label_map = {0: 'contradiction', 1: 'neutral', 2: 'entailment'}
ncorrect, nsamples = 0, 0
roberta.cuda()
roberta.eval()
with open('glue_data/MNLI/dev_matched.tsv') as fin:
    fin.readline()
    for index, line in enumerate(fin):
        tokens = line.strip().split('\t')
        sent1, sent2, target = tokens[8], tokens[9], tokens[-1]
        tokens = roberta.encode(sent1, sent2)
        prediction = roberta.predict('mnli', tokens).argmax().item()
        prediction_label = label_map[prediction]
        ncorrect += int(prediction_label == target)
        nsamples += 1
print('| Accuracy: ', float(ncorrect)/float(nsamples))
# Expected output: 0.9060
```

## Finetuning

- [Finetuning on GLUE](README.glue.md)
- [Finetuning on custom classification tasks (e.g., IMDB)](README.custom_classification.md)
- [Finetuning on Winograd Schema Challenge (WSC)](wsc/README.md)
- [Finetuning on Commonsense QA (CQA)](commonsense_qa/README.md)
- Finetuning on SQuAD: coming soon

## Pretraining using your own data

See the [tutorial for pretraining RoBERTa using your own data](README.pretraining.md).

## Citation

```bibtex
@article{liu2019roberta,
  title = {RoBERTa: A Robustly Optimized BERT Pretraining Approach},
  author = {Yinhan Liu and Myle Ott and Naman Goyal and Jingfei Du and
            Mandar Joshi and Danqi Chen and Omer Levy and Mike Lewis and
            Luke Zettlemoyer and Veselin Stoyanov},
  journal = {arXiv preprint arXiv:1907.11692},
  year = {2019},
}
```
fairseq-0.10.2/examples/roberta/README.pretraining.md ADDED
# Pretraining RoBERTa using your own data

This tutorial will walk you through pretraining RoBERTa over your own data.

### 1) Preprocess the data

Data should be preprocessed following the [language modeling format](/examples/language_model), i.e. each document should be separated by an empty line (only useful with `--sample-break-mode complete_doc`). Lines will be concatenated as a 1D text stream during training.
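To make that format concrete, here is a toy example (the file name `toy.train.raw` is arbitrary): one sentence per line, with documents separated by a blank line.

```python
# Two tiny "documents", one sentence per line, separated by a blank line.
corpus = (
    "The first document starts here.\n"
    "It continues on a second line.\n"
    "\n"
    "The second document is just one line.\n"
)
with open("toy.train.raw", "w", encoding="utf-8") as f:
    f.write(corpus)

# Training concatenates all lines into one long token stream; the blank
# line only matters with --sample-break-mode complete_doc, which keeps
# samples from crossing document boundaries.
docs = [d for d in corpus.split("\n\n") if d.strip()]
print(len(docs))  # 2 documents
```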

We'll use the [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/)
to demonstrate how to preprocess raw text data with the GPT-2 BPE. Of course
this dataset is quite small, so the resulting pretrained model will perform
poorly, but it gives the general idea.

First download the dataset:
```bash
wget https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-raw-v1.zip
unzip wikitext-103-raw-v1.zip
```

Next encode it with the GPT-2 BPE:
```bash
mkdir -p gpt2_bpe
wget -O gpt2_bpe/encoder.json https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
wget -O gpt2_bpe/vocab.bpe https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
for SPLIT in train valid test; do \
    python -m examples.roberta.multiprocessing_bpe_encoder \
        --encoder-json gpt2_bpe/encoder.json \
        --vocab-bpe gpt2_bpe/vocab.bpe \
        --inputs wikitext-103-raw/wiki.${SPLIT}.raw \
        --outputs wikitext-103-raw/wiki.${SPLIT}.bpe \
        --keep-empty \
        --workers 60; \
done
```

Finally preprocess/binarize the data using the GPT-2 fairseq dictionary:
```bash
wget -O gpt2_bpe/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
fairseq-preprocess \
    --only-source \
    --srcdict gpt2_bpe/dict.txt \
    --trainpref wikitext-103-raw/wiki.train.bpe \
    --validpref wikitext-103-raw/wiki.valid.bpe \
    --testpref wikitext-103-raw/wiki.test.bpe \
    --destdir data-bin/wikitext-103 \
    --workers 60
```

### 2) Train RoBERTa base
```bash
TOTAL_UPDATES=125000    # Total number of training steps
WARMUP_UPDATES=10000    # Warmup the learning rate over this many updates
PEAK_LR=0.0005          # Peak learning rate, adjust as needed
TOKENS_PER_SAMPLE=512   # Max sequence length
MAX_POSITIONS=512       # Num. positional embeddings (usually same as above)
MAX_SENTENCES=16        # Number of sequences per batch (batch size)
UPDATE_FREQ=16          # Increase the batch size 16x

DATA_DIR=data-bin/wikitext-103

fairseq-train --fp16 $DATA_DIR \
    --task masked_lm --criterion masked_lm \
    --arch roberta_base --sample-break-mode complete --tokens-per-sample $TOKENS_PER_SAMPLE \
    --optimizer adam --adam-betas '(0.9,0.98)' --adam-eps 1e-6 --clip-norm 0.0 \
    --lr-scheduler polynomial_decay --lr $PEAK_LR --warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_UPDATES \
    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --batch-size $MAX_SENTENCES --update-freq $UPDATE_FREQ \
    --max-update $TOTAL_UPDATES --log-format simple --log-interval 1
```

**Note:** You can optionally resume training the released RoBERTa base model by
adding `--restore-file /path/to/roberta.base/model.pt`.

**Note:** The above command assumes training on 8x32GB V100 GPUs. Each GPU uses
a batch size of 16 sequences (`$MAX_SENTENCES`) and accumulates gradients to
further increase the batch size by 16x (`$UPDATE_FREQ`), for a total batch size
of 2048 sequences. If you have fewer GPUs or GPUs with less memory, you may need
to reduce `$MAX_SENTENCES` and increase `$UPDATE_FREQ` to compensate.
Alternatively, if you have more GPUs you can decrease `$UPDATE_FREQ` accordingly
to increase training speed.
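The effective batch size in that note is simply the product of the per-GPU batch size, the gradient-accumulation factor, and the GPU count; a quick sanity check:

```python
def effective_batch_size(sequences_per_gpu, update_freq, num_gpus):
    """Sequences contributing to each optimizer step."""
    return sequences_per_gpu * update_freq * num_gpus

# 16 sequences/GPU x 16 accumulation steps x 8 GPUs = 2048 sequences
print(effective_batch_size(16, 16, 8))  # 2048

# e.g. on 2 GPUs you would need UPDATE_FREQ=64 for the same effective batch
assert effective_batch_size(16, 64, 2) == 2048
```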

**Note:** The learning rate and batch size are tightly connected and need to be
adjusted together. We generally recommend increasing the learning rate as you
increase the batch size according to the following table (although it's also
dataset dependent, so don't rely on the following values too closely):

batch size | peak learning rate
---|---
256 | 0.0001
2048 | 0.0005
8192 | 0.0007

### 3) Load your pretrained model
```python
import torch
from fairseq.models.roberta import RobertaModel
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'path/to/data')
assert isinstance(roberta.model, torch.nn.Module)
```
fairseq-0.10.2/examples/roberta/README.race.md ADDED
# Finetuning RoBERTa on RACE tasks

### 1) Download the data from the RACE website (http://www.cs.cmu.edu/~glai1/data/race/)

### 2) Preprocess RACE data:
```bash
python ./examples/roberta/preprocess_RACE.py --input-dir <input-dir> --output-dir <extracted-data-dir>
./examples/roberta/preprocess_RACE.sh <extracted-data-dir> <output-dir>
```

### 3) Fine-tuning on RACE:

```bash
MAX_EPOCH=5        # Number of training epochs.
LR=1e-05           # Peak LR for fixed LR scheduler.
NUM_CLASSES=4
MAX_SENTENCES=1    # Batch size per GPU.
UPDATE_FREQ=8      # Accumulate gradients to simulate training on 8 GPUs.
DATA_DIR=/path/to/race-output-dir
ROBERTA_PATH=/path/to/roberta/model.pt

CUDA_VISIBLE_DEVICES=0,1 fairseq-train $DATA_DIR --ddp-backend=no_c10d \
    --restore-file $ROBERTA_PATH \
    --reset-optimizer --reset-dataloader --reset-meters \
    --best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
    --task sentence_ranking \
    --num-classes $NUM_CLASSES \
    --init-token 0 --separator-token 2 \
    --max-option-length 128 \
    --max-positions 512 \
    --shorten-method "truncate" \
    --arch roberta_large \
    --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
    --criterion sentence_ranking \
    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
    --clip-norm 0.0 \
    --lr-scheduler fixed --lr $LR \
    --fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
    --batch-size $MAX_SENTENCES \
    --required-batch-size-multiple 1 \
    --update-freq $UPDATE_FREQ \
    --max-epoch $MAX_EPOCH
```

**Note:**

a) As contexts in RACE are relatively long, we use a smaller batch size per GPU while increasing `--update-freq` to achieve a larger effective batch size.

b) The above cmd-args and hyperparams were tested on a single Nvidia `V100` GPU with `32gb` of memory for each task. Depending on the GPU memory available to you, you may need to increase `--update-freq` and reduce `--batch-size`.

c) The settings in the above command are based on our hyperparam search within a fixed search space (for careful comparison across models). You might be able to find better metrics with a wider hyperparam search.

### 4) Evaluation:

```bash
DATA_DIR=/path/to/race-output-dir       # data directory used during training
MODEL_PATH=/path/to/checkpoint_best.pt  # path to the finetuned model checkpoint
PREDS_OUT=preds.tsv                     # output file path to save prediction
TEST_SPLIT=test                         # can be test (Middle) or test1 (High)
fairseq-validate \
    $DATA_DIR \
    --valid-subset $TEST_SPLIT \
    --path $MODEL_PATH \
    --batch-size 1 \
    --task sentence_ranking \
    --criterion sentence_ranking \
    --save-predictions $PREDS_OUT
```
fairseq-0.10.2/examples/roberta/multiprocessing_bpe_encoder.py ADDED
#!/usr/bin/env python
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import contextlib
import sys
from collections import Counter
from multiprocessing import Pool

from fairseq.data.encoders.gpt2_bpe import get_encoder


def main():
    """
    Helper script to encode raw text with the GPT-2 BPE using multiple processes.

    The encoder.json and vocab.bpe files can be obtained here:
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json
    - https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--encoder-json",
        help="path to encoder.json",
    )
    parser.add_argument(
        "--vocab-bpe",
        type=str,
        help="path to vocab.bpe",
    )
    parser.add_argument(
        "--inputs",
        nargs="+",
        default=["-"],
        help="input files to filter/encode",
    )
    parser.add_argument(
        "--outputs",
        nargs="+",
        default=["-"],
        help="path to save encoded outputs",
    )
    parser.add_argument(
        "--keep-empty",
        action="store_true",
        help="keep empty lines",
    )
    parser.add_argument("--workers", type=int, default=20)
    args = parser.parse_args()

    assert len(args.inputs) == len(
        args.outputs
    ), "number of input and output paths should match"

    with contextlib.ExitStack() as stack:
        inputs = [
            stack.enter_context(open(input, "r", encoding="utf-8"))
            if input != "-"
            else sys.stdin
            for input in args.inputs
        ]
        outputs = [
            stack.enter_context(open(output, "w", encoding="utf-8"))
            if output != "-"
            else sys.stdout
            for output in args.outputs
        ]

        encoder = MultiprocessingEncoder(args)
        pool = Pool(args.workers, initializer=encoder.initializer)
        encoded_lines = pool.imap(encoder.encode_lines, zip(*inputs), 100)

        stats = Counter()
        for i, (filt, enc_lines) in enumerate(encoded_lines, start=1):
            if filt == "PASS":
                for enc_line, output_h in zip(enc_lines, outputs):
                    print(enc_line, file=output_h)
            else:
                stats["num_filtered_" + filt] += 1
            if i % 10000 == 0:
                print("processed {} lines".format(i), file=sys.stderr)

        for k, v in stats.most_common():
            print("[{}] filtered {} lines".format(k, v), file=sys.stderr)


class MultiprocessingEncoder(object):
    def __init__(self, args):
        self.args = args

    def initializer(self):
        global bpe
        bpe = get_encoder(self.args.encoder_json, self.args.vocab_bpe)

    def encode(self, line):
        global bpe
        ids = bpe.encode(line)
        return list(map(str, ids))

    def decode(self, tokens):
        global bpe
        return bpe.decode(tokens)

    def encode_lines(self, lines):
        """
        Encode a set of lines. All lines will be encoded together.
        """
        enc_lines = []
        for line in lines:
            line = line.strip()
            if len(line) == 0 and not self.args.keep_empty:
                return ["EMPTY", None]
            tokens = self.encode(line)
            enc_lines.append(" ".join(tokens))
        return ["PASS", enc_lines]

    def decode_lines(self, lines):
        dec_lines = []
        for line in lines:
            tokens = map(int, line.strip().split())
            dec_lines.append(self.decode(tokens))
        return ["PASS", dec_lines]


if __name__ == "__main__":
    main()
fairseq-0.10.2/examples/roberta/preprocess_GLUE_tasks.sh ADDED
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# raw glue data as downloaded by glue download script (https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e)
if [[ $# -ne 2 ]]; then
    echo "Run as following:"
    echo "./examples/roberta/preprocess_GLUE_tasks.sh <glue_data_folder> <task_name>"
    exit 1
fi

GLUE_DATA_FOLDER=$1

# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

TASKS=$2 # QQP

if [ "$TASKS" = "ALL" ]
then
    TASKS="QQP MNLI QNLI MRPC RTE STS-B SST-2 CoLA"
fi

for TASK in $TASKS
do
    echo "Preprocessing $TASK"

    TASK_DATA_FOLDER="$GLUE_DATA_FOLDER/$TASK"
    echo "Raw data as downloaded from glue website: $TASK_DATA_FOLDER"

    SPLITS="train dev test"
    INPUT_COUNT=2
    if [ "$TASK" = "QQP" ]
    then
        INPUT_COLUMNS=( 4 5 )
        TEST_INPUT_COLUMNS=( 2 3 )
        LABEL_COLUMN=6
    elif [ "$TASK" = "MNLI" ]
    then
        SPLITS="train dev_matched dev_mismatched test_matched test_mismatched"
        INPUT_COLUMNS=( 9 10 )
        TEST_INPUT_COLUMNS=( 9 10 )
        DEV_LABEL_COLUMN=16
        LABEL_COLUMN=12
    elif [ "$TASK" = "QNLI" ]
    then
        INPUT_COLUMNS=( 2 3 )
        TEST_INPUT_COLUMNS=( 2 3 )
        LABEL_COLUMN=4
    elif [ "$TASK" = "MRPC" ]
    then
        INPUT_COLUMNS=( 4 5 )
        TEST_INPUT_COLUMNS=( 4 5 )
        LABEL_COLUMN=1
    elif [ "$TASK" = "RTE" ]
    then
        INPUT_COLUMNS=( 2 3 )
        TEST_INPUT_COLUMNS=( 2 3 )
        LABEL_COLUMN=4
    elif [ "$TASK" = "STS-B" ]
    then
        INPUT_COLUMNS=( 8 9 )
        TEST_INPUT_COLUMNS=( 8 9 )
        LABEL_COLUMN=10
    # Following are single sentence tasks.
    elif [ "$TASK" = "SST-2" ]
    then
        INPUT_COLUMNS=( 1 )
        TEST_INPUT_COLUMNS=( 2 )
        LABEL_COLUMN=2
        INPUT_COUNT=1
    elif [ "$TASK" = "CoLA" ]
    then
        INPUT_COLUMNS=( 4 )
        TEST_INPUT_COLUMNS=( 2 )
        LABEL_COLUMN=2
        INPUT_COUNT=1
    fi

    # Strip out header and filter lines that don't have expected number of fields.
    rm -rf "$TASK_DATA_FOLDER/processed"
    mkdir -p "$TASK_DATA_FOLDER/processed"
    for SPLIT in $SPLITS
    do
        # CoLA train and dev don't have a header.
        if [[ ( "$TASK" = "CoLA") && ( "$SPLIT" != "test" ) ]]
        then
            cp "$TASK_DATA_FOLDER/$SPLIT.tsv" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
        else
            tail -n +2 "$TASK_DATA_FOLDER/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
        fi

        # Remove unformatted lines from train and dev files for QQP dataset.
        if [[ ( "$TASK" = "QQP") && ( "$SPLIT" != "test" ) ]]
        then
            awk -F '\t' -v NUM_FIELDS=6 'NF==NUM_FIELDS{print}{}' "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" > "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
        else
            cp "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv";
        fi
        rm "$TASK_DATA_FOLDER/processed/$SPLIT.tsv.temp";
    done

    # Split into input0, input1 and label
    for SPLIT in $SPLITS
    do
        for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
        do
            if [[ "$SPLIT" != test* ]]
            then
                COLUMN_NUMBER=${INPUT_COLUMNS[$INPUT_TYPE]}
            else
                COLUMN_NUMBER=${TEST_INPUT_COLUMNS[$INPUT_TYPE]}
            fi
            cut -f"$COLUMN_NUMBER" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.raw.input$INPUT_TYPE";
        done

        if [[ "$SPLIT" != test* ]]
        then
            if [ "$TASK" = "MNLI" ] && [ "$SPLIT" != "train" ]
            then
                cut -f"$DEV_LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
            else
                cut -f"$LABEL_COLUMN" "$TASK_DATA_FOLDER/processed/$SPLIT.tsv" > "$TASK_DATA_FOLDER/processed/$SPLIT.label";
            fi
        fi

        # BPE encode.
        for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
        do
            LANG="input$INPUT_TYPE"
            echo "BPE encoding $SPLIT/$LANG"
            python -m examples.roberta.multiprocessing_bpe_encoder \
                --encoder-json encoder.json \
                --vocab-bpe vocab.bpe \
                --inputs "$TASK_DATA_FOLDER/processed/$SPLIT.raw.$LANG" \
                --outputs "$TASK_DATA_FOLDER/processed/$SPLIT.$LANG" \
                --workers 60 \
                --keep-empty;
        done
    done

    # Remove output directory.
    rm -rf "$TASK-bin"

    DEVPREF="$TASK_DATA_FOLDER/processed/dev.LANG"
    TESTPREF="$TASK_DATA_FOLDER/processed/test.LANG"
    if [ "$TASK" = "MNLI" ]
    then
        DEVPREF="$TASK_DATA_FOLDER/processed/dev_matched.LANG,$TASK_DATA_FOLDER/processed/dev_mismatched.LANG"
        TESTPREF="$TASK_DATA_FOLDER/processed/test_matched.LANG,$TASK_DATA_FOLDER/processed/test_mismatched.LANG"
    fi

    # Run fairseq preprocessing:
    for INPUT_TYPE in $(seq 0 $((INPUT_COUNT-1)))
    do
        LANG="input$INPUT_TYPE"
        fairseq-preprocess \
            --only-source \
            --trainpref "$TASK_DATA_FOLDER/processed/train.$LANG" \
            --validpref "${DEVPREF//LANG/$LANG}" \
            --testpref "${TESTPREF//LANG/$LANG}" \
            --destdir "$TASK-bin/$LANG" \
            --workers 60 \
            --srcdict dict.txt;
    done
    if [[ "$TASK" != "STS-B" ]]
    then
        fairseq-preprocess \
            --only-source \
            --trainpref "$TASK_DATA_FOLDER/processed/train.label" \
            --validpref "${DEVPREF//LANG/label}" \
            --destdir "$TASK-bin/label" \
            --workers 60;
    else
        # For STS-B output range is converted to be between: [0.0, 1.0]
        mkdir -p "$TASK-bin/label"
        awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/train.label" > "$TASK-bin/label/train.label"
        awk '{print $1 / 5.0 }' "$TASK_DATA_FOLDER/processed/dev.label" > "$TASK-bin/label/valid.label"
    fi
done
fairseq-0.10.2/examples/roberta/preprocess_RACE.sh ADDED
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


# data should be downloaded and processed with preprocess_RACE.py
if [[ $# -ne 2 ]]; then
    echo "Run as following:"
    echo "./examples/roberta/preprocess_RACE.sh <race_data_folder> <output_folder>"
    exit 1
fi

RACE_DATA_FOLDER=$1
OUT_DATA_FOLDER=$2

# download bpe encoder.json, vocabulary and fairseq dictionary
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

SPLITS="train dev test-middle test-high"
INPUT_TYPES="input0 input1 input2 input3 input4"
for INPUT_TYPE in $INPUT_TYPES
do
    for SPLIT in $SPLITS
    do
        echo "BPE encoding $SPLIT/$INPUT_TYPE"
        python -m examples.roberta.multiprocessing_bpe_encoder \
            --encoder-json encoder.json \
            --vocab-bpe vocab.bpe \
            --inputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE" \
            --outputs "$RACE_DATA_FOLDER/$SPLIT.$INPUT_TYPE.bpe" \
            --workers 10 \
            --keep-empty;
    done
done

for INPUT_TYPE in $INPUT_TYPES
do
    fairseq-preprocess \
        --only-source \
        --trainpref "$RACE_DATA_FOLDER/train.$INPUT_TYPE.bpe" \
        --validpref "$RACE_DATA_FOLDER/dev.$INPUT_TYPE.bpe" \
        --testpref "$RACE_DATA_FOLDER/test-middle.$INPUT_TYPE.bpe,$RACE_DATA_FOLDER/test-high.$INPUT_TYPE.bpe" \
        --destdir "$OUT_DATA_FOLDER/$INPUT_TYPE" \
        --workers 10 \
        --srcdict dict.txt;
done

rm -rf "$OUT_DATA_FOLDER/label"
mkdir -p "$OUT_DATA_FOLDER/label"
cp "$RACE_DATA_FOLDER/train.label" "$OUT_DATA_FOLDER/label/"
cp "$RACE_DATA_FOLDER/dev.label" "$OUT_DATA_FOLDER/label/valid.label"
cp "$RACE_DATA_FOLDER/test-middle.label" "$OUT_DATA_FOLDER/label/test.label"
cp "$RACE_DATA_FOLDER/test-high.label" "$OUT_DATA_FOLDER/label/test1.label"
fairseq-0.10.2/examples/speech_recognition/README.md ADDED
@@ -0,0 +1,106 @@
+ # Speech Recognition
+ `examples/speech_recognition` implements the ASR task in fairseq, along with the features, datasets, models and loss functions needed to train and run inference with the model described in [Transformers with convolutional context for ASR (Abdelrahman Mohamed et al., 2019)](https://arxiv.org/abs/1904.11660).
+
+
+ ## Additional dependencies
+ On top of the main fairseq dependencies there are a couple of additional requirements.
+
+ 1) Please follow the instructions to install [torchaudio](https://github.com/pytorch/audio). This is required to compute audio fbank features.
+ 2) [Sclite](http://www1.icsi.berkeley.edu/Speech/docs/sctk-1.2/sclite.htm#sclite_name_0) is used to measure WER. Sclite can be downloaded and built from source as part of the sctk package [here](http://www.openslr.org/4/). Training and inference do not require Sclite.
+ 3) [sentencepiece](https://github.com/google/sentencepiece) is required in order to create datasets with word-piece targets.
+
+ ## Preparing librispeech data
+ ```
+ ./examples/speech_recognition/datasets/prepare-librispeech.sh $DIR_TO_SAVE_RAW_DATA $DIR_FOR_PREPROCESSED_DATA
+ ```
+
+ ## Training librispeech data
+ ```
+ python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 80 --task speech_recognition --arch vggtransformer_2 --optimizer adadelta --lr 1.0 --adadelta-eps 1e-8 --adadelta-rho 0.95 --clip-norm 10.0 --max-tokens 5000 --log-format json --log-interval 1 --criterion cross_entropy_acc --user-dir examples/speech_recognition/
+ ```
+
+ ## Inference for librispeech
+ `$SET` can be `test_clean` or `test_other`.
+ Any checkpoint in `$MODEL_PATH` can be selected. In this example we use `checkpoint_last.pt`.
+ ```
+ python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --max-tokens 25000 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --beam 20 --results-path $RES_DIR --batch-size 40 --gen-subset $SET --user-dir examples/speech_recognition/
+ ```
+
+ ## Scoring with sclite
+ ```
+ sclite -r ${RES_DIR}/ref.word-checkpoint_last.pt-${SET}.txt -h ${RES_DIR}/hypo.word-checkpoint_last.pt-${SET}.txt -i rm -o all stdout > $RES_REPORT
+ ```
+ The `Sum/Avg` row in the first table of the report contains the WER.
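If sclite is unavailable, a rough WER can also be computed directly from the `ref.word-*` and `hypo.word-*` files that `infer.py` writes. A minimal pure-Python sketch (word-level edit distance over whitespace-tokenized lines, without sclite-style text normalization):

```python
def edit_distance(ref, hyp):
    """Word-level Levenshtein distance between two token lists."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        cur = [i]
        for j, h in enumerate(hyp, 1):
            cur.append(min(prev[j] + 1,               # deletion
                           cur[j - 1] + 1,            # insertion
                           prev[j - 1] + (r != h)))   # substitution
        prev = cur
    return prev[-1]

def wer(ref_lines, hyp_lines):
    """Total word errors over total reference words, as a percentage."""
    errs = words = 0
    for ref, hyp in zip(ref_lines, hyp_lines):
        r, h = ref.split(), hyp.split()
        errs += edit_distance(r, h)
        words += len(r)
    return 100.0 * errs / max(words, 1)
```

This matches the `editdistance.eval` bookkeeping used inside `infer.py`, just without the per-utterance speaker/id suffixes.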
+
+ ## Using wav2letter components
+ [wav2letter](https://github.com/facebookresearch/wav2letter) now has integration with fairseq. Currently this includes:
+
+ * AutoSegmentationCriterion (ASG)
+ * wav2letter-style Conv/GLU model
+ * wav2letter's beam search decoder
+
+ To use these, follow the instructions on [this page](https://github.com/facebookresearch/wav2letter/tree/master/bindings/python) to install the python bindings. Please note that the python bindings cover a *subset* of wav2letter and don't require its full dependencies (notably, `flashlight` and `ArrayFire` are *not* required).
+
+ To quickly summarize the instructions: first, install [CUDA](https://developer.nvidia.com/cuda-downloads). Then follow these steps:
+ ```
+ # additional prerequisites - use equivalents for your distro
+ sudo apt-get install build-essential cmake libatlas-base-dev libfftw3-dev liblzma-dev libbz2-dev libzstd-dev
+ # install KenLM from source
+ git clone https://github.com/kpu/kenlm.git
+ cd kenlm
+ mkdir -p build && cd build
+ cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_POSITION_INDEPENDENT_CODE=ON
+ make -j16
+ cd ..
+ export KENLM_ROOT_DIR=$(pwd)
+ cd ..
+ # install wav2letter python bindings
+ git clone https://github.com/facebookresearch/wav2letter.git
+ cd wav2letter/bindings/python
+ # make sure your python environment is active at this point
+ pip install torch packaging
+ pip install -e .
+ # try some examples to verify installation succeeded
+ python ./examples/criterion_example.py
+ python ./examples/decoder_example.py ../../src/decoder/test
+ python ./examples/feature_example.py ../../src/feature/test/data
+ ```
+
+ ## Training librispeech data (wav2letter style, Conv/GLU + ASG loss)
+ Training command:
+ ```
+ python train.py $DIR_FOR_PREPROCESSED_DATA --save-dir $MODEL_PATH --max-epoch 100 --task speech_recognition --arch w2l_conv_glu_enc --batch-size 4 --optimizer sgd --lr 0.3,0.8 --momentum 0.8 --clip-norm 0.2 --max-tokens 50000 --log-format json --log-interval 100 --num-workers 0 --sentence-avg --criterion asg_loss --asg-transitions-init 5 --max-replabel 2 --linseg-updates 8789 --user-dir examples/speech_recognition
+ ```
+
+ Note that ASG loss currently doesn't work well with word-pieces. You should prepare a dataset with character targets by setting `nbpe=31` in `prepare-librispeech.sh`.
+
+ ## Inference for librispeech (wav2letter decoder, n-gram LM)
+ Inference command:
+ ```
+ python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder kenlm --kenlm-model $KENLM_MODEL_PATH --lexicon $LEXICON_PATH --beam 200 --beam-threshold 15 --lm-weight 1.5 --word-score 1.5 --sil-weight -0.3 --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
+ ```
+
+ `$KENLM_MODEL_PATH` should be a standard n-gram language model file. `$LEXICON_PATH` should be a wav2letter-style lexicon (a list of known words and their spellings). For ASG inference, a lexicon line should look like this (note the repetition labels):
+ ```
+ doorbell D O 1 R B E L 1 ▁
+ ```
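The repetition labels in the lexicon line above come from wav2letter's "replabel" encoding: a run of identical letters is collapsed into the letter plus a count token, bounded by `--max-replabel`. A simplified, illustrative sketch of that encoding (not fairseq's or wav2letter's actual implementation):

```python
def replabel_encode(word, max_replabel=2):
    """Collapse runs of identical characters into <CHAR> <run_length - 1>
    tokens, splitting runs longer than max_replabel + 1 characters."""
    out = []
    i, n = 0, len(word)
    while i < n:
        j = i
        # consume at most max_replabel + 1 copies of the same character
        while j < n and word[j] == word[i] and j - i <= max_replabel:
            j += 1
        out.append(word[i].upper())
        if j - i > 1:
            out.append(str(j - i - 1))  # repetition label
        i = j
    return out
```

`replabel_encode("doorbell")` yields `['D', 'O', '1', 'R', 'B', 'E', 'L', '1']`; appending the word-boundary symbol `▁` gives the ASG lexicon spelling shown above.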
+ For CTC inference with word-pieces, repetition labels are not used and the lexicon should contain the most common spellings for each word (one can use sentencepiece's `NBestEncodeAsPieces` for this):
+ ```
+ doorbell ▁DOOR BE LL
+ doorbell ▁DOOR B E LL
+ doorbell ▁DO OR BE LL
+ doorbell ▁DOOR B EL L
+ doorbell ▁DOOR BE L L
+ doorbell ▁DO OR B E LL
+ doorbell ▁DOOR B E L L
+ doorbell ▁DO OR B EL L
+ doorbell ▁DO O R BE LL
+ doorbell ▁DO OR BE L L
+ ```
+ Lowercase vs. uppercase matters: the *word* should match the case of the n-gram language model (i.e. `$KENLM_MODEL_PATH`), while the *spelling* should match the case of the token dictionary (i.e. `$DIR_FOR_PREPROCESSED_DATA/dict.txt`).
+
+ ## Inference for librispeech (wav2letter decoder, viterbi only)
+ Inference command:
+ ```
+ python examples/speech_recognition/infer.py $DIR_FOR_PREPROCESSED_DATA --task speech_recognition --seed 1 --nbest 1 --path $MODEL_PATH/checkpoint_last.pt --gen-subset $SET --results-path $RES_DIR --w2l-decoder viterbi --criterion asg_loss --max-replabel 2 --user-dir examples/speech_recognition
+ ```
fairseq-0.10.2/examples/speech_recognition/__init__.py ADDED
@@ -0,0 +1 @@
+ from . import criterions, models, tasks # noqa
fairseq-0.10.2/examples/speech_recognition/criterions/__init__.py ADDED
@@ -0,0 +1,17 @@
+ import importlib
+ import os
+
+
+ # ASG loss requires wav2letter
+ files_to_skip = set()
+ try:
+     import wav2letter
+ except ImportError:
+     files_to_skip.add("ASG_loss.py")
+
+ for file in os.listdir(os.path.dirname(__file__)):
+     if file.endswith(".py") and not file.startswith("_") and file not in files_to_skip:
+         criterion_name = file[: file.find(".py")]
+         importlib.import_module(
+             "examples.speech_recognition.criterions." + criterion_name
+         )
fairseq-0.10.2/examples/speech_recognition/criterions/cross_entropy_acc.py ADDED
@@ -0,0 +1,130 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import absolute_import, division, print_function, unicode_literals
+
+ import logging
+ import math
+
+ import torch
+ import torch.nn.functional as F
+ from fairseq import utils
+ from fairseq.criterions import FairseqCriterion, register_criterion
+
+
+ @register_criterion("cross_entropy_acc")
+ class CrossEntropyWithAccCriterion(FairseqCriterion):
+     def __init__(self, task, sentence_avg):
+         super().__init__(task)
+         self.sentence_avg = sentence_avg
+
+     def compute_loss(self, model, net_output, target, reduction, log_probs):
+         # N, T -> N * T
+         target = target.view(-1)
+         lprobs = model.get_normalized_probs(net_output, log_probs=log_probs)
+         if not hasattr(lprobs, "batch_first"):
+             logging.warning(
+                 "ERROR: we need to know whether "
+                 "batch first for the net output; "
+                 "you need to set batch_first attribute for the return value of "
+                 "model.get_normalized_probs. Now, we assume this is true, but "
+                 "in the future, we will raise exception instead. "
+             )
+         batch_first = getattr(lprobs, "batch_first", True)
+         if not batch_first:
+             lprobs = lprobs.transpose(0, 1)
+
+         # N, T, D -> N * T, D
+         lprobs = lprobs.view(-1, lprobs.size(-1))
+         loss = F.nll_loss(
+             lprobs, target, ignore_index=self.padding_idx, reduction=reduction
+         )
+         return lprobs, loss
+
+     def get_logging_output(self, sample, target, lprobs, loss):
+         target = target.view(-1)
+         mask = target != self.padding_idx
+         correct = torch.sum(
+             lprobs.argmax(1).masked_select(mask) == target.masked_select(mask)
+         )
+         total = torch.sum(mask)
+         sample_size = (
+             sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
+         )
+
+         logging_output = {
+             "loss": utils.item(loss.data),  # * sample['ntokens'],
+             "ntokens": sample["ntokens"],
+             "nsentences": sample["target"].size(0),
+             "sample_size": sample_size,
+             "correct": utils.item(correct.data),
+             "total": utils.item(total.data),
+             "nframes": torch.sum(sample["net_input"]["src_lengths"]).item(),
+         }
+
+         return sample_size, logging_output
+
+     def forward(self, model, sample, reduction="sum", log_probs=True):
+         """Computes the cross entropy with accuracy metric for the given sample.
+
+         This is similar to CrossEntropyCriterion in fairseq, but also
+         computes accuracy metrics as part of logging
+
+         Args:
+             logprobs (Torch.tensor) of shape N, T, D i.e.
+                 batchsize, timesteps, dimensions
+             targets (Torch.tensor) of shape N, T i.e. batchsize, timesteps
+
+         Returns:
+             tuple: With three elements:
+                 1) the loss
+                 2) the sample size, which is used as the denominator for the gradient
+                 3) logging outputs to display while training
+
+         TODO:
+             * Currently this Criterion will only work with LSTMEncoderModels or
+               FairseqModels which have decoder, or Models which return TorchTensor
+               as net_output.
+               We need to make a change to support all FairseqEncoder models.
+         """
+         net_output = model(**sample["net_input"])
+         target = model.get_targets(sample, net_output)
+         lprobs, loss = self.compute_loss(
+             model, net_output, target, reduction, log_probs
+         )
+         sample_size, logging_output = self.get_logging_output(
+             sample, target, lprobs, loss
+         )
+         return loss, sample_size, logging_output
+
+     @staticmethod
+     def aggregate_logging_outputs(logging_outputs):
+         """Aggregate logging outputs from data parallel training."""
+         correct_sum = sum(log.get("correct", 0) for log in logging_outputs)
+         total_sum = sum(log.get("total", 0) for log in logging_outputs)
+         loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+         ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
+         nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
+         sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+         nframes = sum(log.get("nframes", 0) for log in logging_outputs)
+         agg_output = {
+             "loss": loss_sum / sample_size / math.log(2) if sample_size > 0 else 0.0,
+             # if args.sentence_avg, then sample_size is nsentences and loss
+             # is the per-sentence loss; else sample_size is ntokens and loss
+             # is the per-output-token loss
+             "ntokens": ntokens,
+             "nsentences": nsentences,
+             "nframes": nframes,
+             "sample_size": sample_size,
+             "acc": correct_sum * 100.0 / total_sum if total_sum > 0 else 0.0,
+             "correct": correct_sum,
+             "total": total_sum,
+             # total is the number of validated tokens
+         }
+         if sample_size != ntokens:
+             # here loss is the per-sentence loss, so also report the
+             # per-output-token loss as nll_loss
+             agg_output["nll_loss"] = loss_sum / ntokens / math.log(2)
+         return agg_output
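The accuracy bookkeeping in `get_logging_output` counts argmax hits only over non-padding target positions. The same masking logic in plain Python, with a hypothetical padding index:

```python
def masked_correct_total(predictions, targets, padding_idx=1):
    """Count (correct, total) over positions whose target is not padding,
    mirroring the argmax/masked_select computation in the criterion."""
    correct = total = 0
    for pred, tgt in zip(predictions, targets):
        if tgt == padding_idx:
            continue  # padded timestep: excluded from both counts
        total += 1
        correct += int(pred == tgt)
    return correct, total
```

The aggregated `acc` reported during training is then `100.0 * correct / total`.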
fairseq-0.10.2/examples/speech_recognition/datasets/asr_prep_json.py ADDED
@@ -0,0 +1,125 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from __future__ import absolute_import, division, print_function, unicode_literals
+
+ import argparse
+ import concurrent.futures
+ import json
+ import multiprocessing
+ import os
+ from collections import namedtuple
+ from itertools import chain
+
+ import sentencepiece as spm
+ from fairseq.data import Dictionary
+
+
+ MILLISECONDS_TO_SECONDS = 0.001
+
+
+ def process_sample(aud_path, label, utt_id, sp, tgt_dict):
+     import torchaudio
+
+     input = {}
+     output = {}
+     si, ei = torchaudio.info(aud_path)
+     input["length_ms"] = int(
+         si.length / si.channels / si.rate / MILLISECONDS_TO_SECONDS
+     )
+     input["path"] = aud_path
+
+     token = " ".join(sp.EncodeAsPieces(label))
+     ids = tgt_dict.encode_line(token, append_eos=False)
+     output["text"] = label
+     output["token"] = token
+     output["tokenid"] = ", ".join(map(str, [t.tolist() for t in ids]))
+     return {utt_id: {"input": input, "output": output}}
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--audio-dirs",
+         nargs="+",
+         default=["-"],
+         required=True,
+         help="input directories with audio files",
+     )
+     parser.add_argument(
+         "--labels",
+         required=True,
+         help="aggregated input labels with format <ID LABEL> per line",
+         type=argparse.FileType("r", encoding="UTF-8"),
+     )
+     parser.add_argument(
+         "--spm-model",
+         required=True,
+         help="sentencepiece model to use for encoding",
+         type=argparse.FileType("r", encoding="UTF-8"),
+     )
+     parser.add_argument(
+         "--dictionary",
+         required=True,
+         help="file to load fairseq dictionary from",
+         type=argparse.FileType("r", encoding="UTF-8"),
+     )
+     parser.add_argument("--audio-format", choices=["flac", "wav"], default="wav")
+     parser.add_argument(
+         "--output",
+         required=True,
+         type=argparse.FileType("w"),
+         help="path to save json output",
+     )
+     args = parser.parse_args()
+
+     sp = spm.SentencePieceProcessor()
+     sp.Load(args.spm_model.name)
+
+     tgt_dict = Dictionary.load(args.dictionary)
+
+     labels = {}
+     for line in args.labels:
+         (utt_id, label) = line.split(" ", 1)
+         labels[utt_id] = label
+     if len(labels) == 0:
+         raise Exception("No labels found in ", args.labels.name)
+
+     Sample = namedtuple("Sample", "aud_path utt_id")
+     samples = []
+     for path, _, files in chain.from_iterable(
+         os.walk(path) for path in args.audio_dirs
+     ):
+         for f in files:
+             if f.endswith(args.audio_format):
+                 if len(os.path.splitext(f)) != 2:
+                     raise Exception("Expect <utt_id.extension> file name. Got: ", f)
+                 utt_id = os.path.splitext(f)[0]
+                 if utt_id not in labels:
+                     continue
+                 samples.append(Sample(os.path.join(path, f), utt_id))
+
+     utts = {}
+     num_cpu = multiprocessing.cpu_count()
+     with concurrent.futures.ThreadPoolExecutor(max_workers=num_cpu) as executor:
+         future_to_sample = {
+             executor.submit(
+                 process_sample, s.aud_path, labels[s.utt_id], s.utt_id, sp, tgt_dict
+             ): s
+             for s in samples
+         }
+         for future in concurrent.futures.as_completed(future_to_sample):
+             try:
+                 data = future.result()
+             except Exception as exc:
+                 print("generated an exception: ", exc)
+             else:
+                 utts.update(data)
+     json.dump({"utts": utts}, args.output, indent=4)
+
+
+ if __name__ == "__main__":
+     main()
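For reference, the JSON written by the script groups every utterance under a top-level `utts` key, with audio metadata in `input` and the transcript in three parallel forms in `output`. A stdlib-only sketch of one record (all values are dummy placeholders; no torchaudio or sentencepiece needed):

```python
import json

def make_utt_entry(utt_id, aud_path, length_ms, text, pieces, token_ids):
    """Build one utterance record in the schema produced by asr_prep_json.py."""
    return {
        utt_id: {
            "input": {"length_ms": length_ms, "path": aud_path},
            "output": {
                "text": text,
                "token": " ".join(pieces),
                "tokenid": ", ".join(map(str, token_ids)),
            },
        }
    }

utts = {}
utts.update(make_utt_entry("spk1-utt1", "/data/spk1-utt1.flac", 2300,
                           "hello world", ["▁HE", "LLO", "▁WOR", "LD"],
                           [10, 11, 12, 13]))
doc = json.dumps({"utts": utts}, indent=4)
```

This is the structure that the speech_recognition task later parses into dataset examples.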
fairseq-0.10.2/examples/speech_recognition/datasets/prepare-librispeech.sh ADDED
@@ -0,0 +1,88 @@
+ #!/usr/bin/env bash
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ # Prepare librispeech dataset
+
+ base_url=www.openslr.org/resources/12
+ train_dir=train_960
+
+ if [ "$#" -ne 2 ]; then
+   echo "Usage: $0 <download_dir> <out_dir>"
+   echo "e.g.: $0 /tmp/librispeech_raw/ ~/data/librispeech_final"
+   exit 1
+ fi
+
+ download_dir=${1%/}
+ out_dir=${2%/}
+
+ fairseq_root=~/fairseq-py/
+ mkdir -p ${out_dir}
+ cd ${out_dir} || exit
+
+ nbpe=5000
+ bpemode=unigram
+
+ if [ ! -d "$fairseq_root" ]; then
+   echo "$0: Please set correct fairseq_root"
+   exit 1
+ fi
+
+ echo "Data Download"
+ for part in dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500; do
+   url=$base_url/$part.tar.gz
+   if ! wget -P $download_dir $url; then
+     echo "$0: wget failed for $url"
+     exit 1
+   fi
+   if ! tar -C $download_dir -xvzf $download_dir/$part.tar.gz; then
+     echo "$0: error un-tarring archive $download_dir/$part.tar.gz"
+     exit 1
+   fi
+ done
+
+ echo "Merge all train packs into one"
+ mkdir -p ${download_dir}/LibriSpeech/${train_dir}/
+ for part in train-clean-100 train-clean-360 train-other-500; do
+   mv ${download_dir}/LibriSpeech/${part}/* $download_dir/LibriSpeech/${train_dir}/
+ done
+ echo "Merge train text"
+ find ${download_dir}/LibriSpeech/${train_dir}/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/${train_dir}/text
+
+ # Use combined dev-clean and dev-other as validation set
+ find ${download_dir}/LibriSpeech/dev-clean/ ${download_dir}/LibriSpeech/dev-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/valid_text
+ find ${download_dir}/LibriSpeech/test-clean/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-clean/text
+ find ${download_dir}/LibriSpeech/test-other/ -name '*.txt' -exec cat {} \; >> ${download_dir}/LibriSpeech/test-other/text
+
+
+ dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_units.txt
+ encoded=data/lang_char/${train_dir}_${bpemode}${nbpe}_encoded.txt
+ fairseq_dict=data/lang_char/${train_dir}_${bpemode}${nbpe}_fairseq_dict.txt
+ bpemodel=data/lang_char/${train_dir}_${bpemode}${nbpe}
+ echo "dictionary: ${dict}"
+ echo "Dictionary preparation"
+ mkdir -p data/lang_char/
+ echo "<unk> 3" > ${dict}
+ echo "</s> 2" >> ${dict}
+ echo "<pad> 1" >> ${dict}
+ cut -f 2- -d" " ${download_dir}/LibriSpeech/${train_dir}/text > data/lang_char/input.txt
+ spm_train --input=data/lang_char/input.txt --vocab_size=${nbpe} --model_type=${bpemode} --model_prefix=${bpemodel} --input_sentence_size=100000000 --unk_id=3 --eos_id=2 --pad_id=1 --bos_id=-1 --character_coverage=1
+ spm_encode --model=${bpemodel}.model --output_format=piece < data/lang_char/input.txt > ${encoded}
+ cat ${encoded} | tr ' ' '\n' | sort | uniq | awk '{print $0 " " NR+3}' >> ${dict}
+ cat ${encoded} | tr ' ' '\n' | sort | uniq -c | awk '{print $2 " " $1}' > ${fairseq_dict}
+ wc -l ${dict}
+
+ echo "Prepare train and test jsons"
+ for part in train_960 test-other test-clean; do
+   python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/${part} --labels ${download_dir}/LibriSpeech/${part}/text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output ${part}.json
+ done
+ # fairseq expects to find train.json and valid.json during training
+ mv train_960.json train.json
+
+ echo "Prepare valid json"
+ python ${fairseq_root}/examples/speech_recognition/datasets/asr_prep_json.py --audio-dirs ${download_dir}/LibriSpeech/dev-clean ${download_dir}/LibriSpeech/dev-other --labels ${download_dir}/LibriSpeech/valid_text --spm-model ${bpemodel}.model --audio-format flac --dictionary ${fairseq_dict} --output valid.json
+
+ cp ${fairseq_dict} ./dict.txt
+ cp ${bpemodel}.model ./spm.model
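The two awk pipelines in the dictionary-preparation step derive both dictionaries from the BPE-encoded text: a numbered units file (each distinct piece gets id `NR+3`, after the reserved `<unk>`/`</s>`/`<pad>` ids) and a fairseq-style `<piece> <count>` file. The same derivation in Python (toy input; assumes the same piece ordering as C-locale `sort`):

```python
from collections import Counter

def build_dicts(encoded_lines, reserved=3):
    """From spm-encoded lines, return (units, freqs): units maps each
    distinct piece to a unique id after the `reserved` special symbols
    (the awk NR+3), freqs maps each piece to its corpus frequency."""
    counts = Counter(p for line in encoded_lines for p in line.split())
    units = {p: i + reserved + 1 for i, p in enumerate(sorted(counts))}
    return units, dict(counts)
```

Writing `units` as `<piece> <id>` lines reproduces `*_units.txt`, and `freqs` as `<piece> <count>` lines reproduces the `--srcdict`-compatible `*_fairseq_dict.txt`.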
fairseq-0.10.2/examples/speech_recognition/infer.py ADDED
@@ -0,0 +1,464 @@
+ #!/usr/bin/env python3 -u
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ Run inference for pre-processed data with a trained model.
+ """
+
+ import logging
+ import math
+ import os
+ import sys
+
+ import editdistance
+ import numpy as np
+ import torch
+ from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
+ from fairseq.data.data_utils import post_process
+ from fairseq.logging.meters import StopwatchMeter, TimeMeter
+
+
+ logging.basicConfig()
+ logging.root.setLevel(logging.INFO)
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+
+ def add_asr_eval_argument(parser):
+     parser.add_argument("--kspmodel", default=None, help="sentence piece model")
+     parser.add_argument(
+         "--wfstlm", default=None, help="wfstlm on dictionary output units"
+     )
+     parser.add_argument(
+         "--rnnt_decoding_type",
+         default="greedy",
+         help="wfstlm on dictionary output units",
+     )
+     try:
+         parser.add_argument(
+             "--lm-weight",
+             "--lm_weight",
+             type=float,
+             default=0.2,
+             help="weight for lm while interpolating with neural score",
+         )
+     except:
+         pass
+     parser.add_argument(
+         "--rnnt_len_penalty", default=-0.5, help="rnnt length penalty on word level"
+     )
+     parser.add_argument(
+         "--w2l-decoder",
+         choices=["viterbi", "kenlm", "fairseqlm"],
+         help="use a w2l decoder",
+     )
+     parser.add_argument("--lexicon", help="lexicon for w2l decoder")
+     parser.add_argument("--unit-lm", action="store_true", help="if using a unit lm")
+     parser.add_argument("--kenlm-model", "--lm-model", help="lm model for w2l decoder")
+     parser.add_argument("--beam-threshold", type=float, default=25.0)
+     parser.add_argument("--beam-size-token", type=float, default=100)
+     parser.add_argument("--word-score", type=float, default=1.0)
+     parser.add_argument("--unk-weight", type=float, default=-math.inf)
+     parser.add_argument("--sil-weight", type=float, default=0.0)
+     parser.add_argument(
+         "--dump-emissions",
+         type=str,
+         default=None,
+         help="if present, dumps emissions into this file and exits",
+     )
+     parser.add_argument(
+         "--dump-features",
+         type=str,
+         default=None,
+         help="if present, dumps features into this file and exits",
+     )
+     parser.add_argument(
+         "--load-emissions",
+         type=str,
+         default=None,
+         help="if present, loads emissions from this file",
+     )
+     return parser
+
+
+ def check_args(args):
+     # assert args.path is not None, "--path required for generation!"
+     # assert args.results_path is not None, "--results_path required for generation!"
+     assert (
+         not args.sampling or args.nbest == args.beam
+     ), "--sampling requires --nbest to be equal to --beam"
+     assert (
+         args.replace_unk is None or args.raw_text
+     ), "--replace-unk requires a raw text dataset (--raw-text)"
+
+
+ def get_dataset_itr(args, task, models):
+     return task.get_batch_iterator(
+         dataset=task.dataset(args.gen_subset),
+         max_tokens=args.max_tokens,
+         max_sentences=args.batch_size,
+         max_positions=(sys.maxsize, sys.maxsize),
+         ignore_invalid_inputs=args.skip_invalid_size_inputs_valid_test,
+         required_batch_size_multiple=args.required_batch_size_multiple,
+         num_shards=args.num_shards,
+         shard_id=args.shard_id,
+         num_workers=args.num_workers,
+         data_buffer_size=args.data_buffer_size,
+     ).next_epoch_itr(shuffle=False)
+
+
+ def process_predictions(
+     args, hypos, sp, tgt_dict, target_tokens, res_files, speaker, id
+ ):
+     for hypo in hypos[: min(len(hypos), args.nbest)]:
+         hyp_pieces = tgt_dict.string(hypo["tokens"].int().cpu())
+
+         if "words" in hypo:
+             hyp_words = " ".join(hypo["words"])
+         else:
+             hyp_words = post_process(hyp_pieces, args.remove_bpe)
+
+         if res_files is not None:
+             print(
+                 "{} ({}-{})".format(hyp_pieces, speaker, id),
+                 file=res_files["hypo.units"],
+             )
+             print(
+                 "{} ({}-{})".format(hyp_words, speaker, id),
+                 file=res_files["hypo.words"],
+             )
+
+         tgt_pieces = tgt_dict.string(target_tokens)
+         tgt_words = post_process(tgt_pieces, args.remove_bpe)
+
+         if res_files is not None:
+             print(
+                 "{} ({}-{})".format(tgt_pieces, speaker, id),
+                 file=res_files["ref.units"],
+             )
+             print(
+                 "{} ({}-{})".format(tgt_words, speaker, id), file=res_files["ref.words"]
+             )
+         # only score top hypothesis
+         if not args.quiet:
+             logger.debug("HYPO:" + hyp_words)
+             logger.debug("TARGET:" + tgt_words)
+             logger.debug("___________________")
+
+         hyp_words = hyp_words.split()
+         tgt_words = tgt_words.split()
+         return editdistance.eval(hyp_words, tgt_words), len(tgt_words)
+
+
+ def prepare_result_files(args):
+     def get_res_file(file_prefix):
+         if args.num_shards > 1:
+             file_prefix = f"{args.shard_id}_{file_prefix}"
+         path = os.path.join(
+             args.results_path,
+             "{}-{}-{}.txt".format(
+                 file_prefix, os.path.basename(args.path), args.gen_subset
+             ),
+         )
+         return open(path, "w", buffering=1)
+
+     if not args.results_path:
+         return None
+
+     return {
+         "hypo.words": get_res_file("hypo.word"),
+         "hypo.units": get_res_file("hypo.units"),
+         "ref.words": get_res_file("ref.word"),
+         "ref.units": get_res_file("ref.units"),
+     }
+
+
+ def load_models_and_criterions(
+     filenames, data_path, arg_overrides=None, task=None, model_state=None
+ ):
+     models = []
+     criterions = []
+
+     if arg_overrides is None:
+         arg_overrides = {}
+
+     arg_overrides["wer_args"] = None
+     arg_overrides["data"] = data_path
+
+     if filenames is None:
+         assert model_state is not None
+         filenames = [0]
+     else:
+         filenames = filenames.split(":")
+
+     for filename in filenames:
+         if model_state is None:
+             if not os.path.exists(filename):
+                 raise IOError("Model file not found: {}".format(filename))
+             state = checkpoint_utils.load_checkpoint_to_cpu(filename, arg_overrides)
+         else:
+             state = model_state
+
+         args = state["args"]
+         if task is None:
+             task = tasks.setup_task(args)
+         model = task.build_model(args)
+         model.load_state_dict(state["model"], strict=True)
+         models.append(model)
+
+         criterion = task.build_criterion(args)
+         if "criterion" in state:
+             criterion.load_state_dict(state["criterion"], strict=True)
+         criterions.append(criterion)
+     return models, criterions, args
+
+
+ def optimize_models(args, use_cuda, models):
+     """Optimize ensemble for generation"""
+     for model in models:
+         model.make_generation_fast_(
+             beamable_mm_beam_size=None if args.no_beamable_mm else args.beam,
+             need_attn=args.print_alignment,
+         )
+         if args.fp16:
+             model.half()
+         if use_cuda:
+             model.cuda()
+
+
+ class ExistingEmissionsDecoder(object):
+     def __init__(self, decoder, emissions):
+         self.decoder = decoder
+         self.emissions = emissions
+
+     def generate(self, models, sample, **unused):
+         ids = sample["id"].cpu().numpy()
+         try:
+             emissions = np.stack(self.emissions[ids])
+         except:
+             print([x.shape for x in self.emissions[ids]])
+             raise Exception("invalid sizes")
+         emissions = torch.from_numpy(emissions)
+         return self.decoder.decode(emissions)
+
+
+ def main(args, task=None, model_state=None):
+     check_args(args)
+
+     if args.max_tokens is None and args.batch_size is None:
+         args.max_tokens = 4000000
+     logger.info(args)
+
+     use_cuda = torch.cuda.is_available() and not args.cpu
+
+     if task is None:
+         # Load dataset splits
+         task = tasks.setup_task(args)
+         task.load_dataset(args.gen_subset)
+
+         logger.info(
+             "| {} {} {} examples".format(
+                 args.data, args.gen_subset, len(task.dataset(args.gen_subset))
+             )
+         )
+
+     # Set dictionary
+     tgt_dict = task.target_dictionary
+
+     logger.info("| decoding with criterion {}".format(args.criterion))
+
+     # Load ensemble
+
+     if args.load_emissions:
+         models, criterions = [], []
+     else:
+         logger.info("| loading model(s) from {}".format(args.path))
+         models, criterions, _ = load_models_and_criterions(
+             args.path,
+             data_path=args.data,
+             arg_overrides=eval(args.model_overrides),  # noqa
+             task=task,
+             model_state=model_state,
+         )
+         optimize_models(args, use_cuda, models)
+
+     # hack to pass transitions to W2lDecoder
+     if args.criterion == "asg_loss":
+         trans = criterions[0].asg.trans.data
+         args.asg_transitions = torch.flatten(trans).tolist()
+
+     # Load dataset (possibly sharded)
+     itr = get_dataset_itr(args, task, models)
+
+     # Initialize generator
+     gen_timer = StopwatchMeter()
+
+     def build_generator(args):
+         w2l_decoder = getattr(args, "w2l_decoder", None)
+         if w2l_decoder == "viterbi":
+             from examples.speech_recognition.w2l_decoder import W2lViterbiDecoder
+
+             return W2lViterbiDecoder(args, task.target_dictionary)
+         elif w2l_decoder == "kenlm":
+             from examples.speech_recognition.w2l_decoder import W2lKenLMDecoder
+
+             return W2lKenLMDecoder(args, task.target_dictionary)
+         elif w2l_decoder == "fairseqlm":
+             from examples.speech_recognition.w2l_decoder import W2lFairseqLMDecoder
312
+
313
+ return W2lFairseqLMDecoder(args, task.target_dictionary)
314
+ else:
315
+ print(
316
+ "only wav2letter decoders with (viterbi, kenlm, fairseqlm) options are supported at the moment"
317
+ )
318
+
319
+ # please do not touch this unless you test both generate.py and infer.py with audio_pretraining task
320
+ generator = build_generator(args)
321
+
322
+ if args.load_emissions:
323
+ generator = ExistingEmissionsDecoder(
324
+ generator, np.load(args.load_emissions, allow_pickle=True)
325
+ )
326
+ logger.info("loaded emissions from " + args.load_emissions)
327
+
328
+ num_sentences = 0
329
+
330
+ if args.results_path is not None and not os.path.exists(args.results_path):
331
+ os.makedirs(args.results_path)
332
+
333
+ max_source_pos = (
334
+ utils.resolve_max_positions(
335
+ task.max_positions(), *[model.max_positions() for model in models]
336
+ ),
337
+ )
338
+
339
+ if max_source_pos is not None:
340
+ max_source_pos = max_source_pos[0]
341
+ if max_source_pos is not None:
342
+ max_source_pos = max_source_pos[0] - 1
343
+
344
+ if args.dump_emissions:
345
+ emissions = {}
346
+ if args.dump_features:
347
+ features = {}
348
+ models[0].bert.proj = None
349
+ else:
350
+ res_files = prepare_result_files(args)
351
+ errs_t = 0
352
+ lengths_t = 0
353
+ with progress_bar.build_progress_bar(args, itr) as t:
354
+ wps_meter = TimeMeter()
355
+ for sample in t:
356
+ sample = utils.move_to_cuda(sample) if use_cuda else sample
357
+ if "net_input" not in sample:
358
+ continue
359
+
360
+ prefix_tokens = None
361
+ if args.prefix_size > 0:
362
+ prefix_tokens = sample["target"][:, : args.prefix_size]
363
+
364
+ gen_timer.start()
365
+ if args.dump_emissions:
366
+ with torch.no_grad():
367
+ encoder_out = models[0](**sample["net_input"])
368
+ emm = models[0].get_normalized_probs(encoder_out, log_probs=True)
369
+ emm = emm.transpose(0, 1).cpu().numpy()
370
+ for i, id in enumerate(sample["id"]):
371
+ emissions[id.item()] = emm[i]
372
+ continue
373
+ elif args.dump_features:
374
+ with torch.no_grad():
375
+ encoder_out = models[0](**sample["net_input"])
376
+ feat = encoder_out["encoder_out"].transpose(0, 1).cpu().numpy()
377
+ for i, id in enumerate(sample["id"]):
378
+ padding = (
379
+ encoder_out["encoder_padding_mask"][i].cpu().numpy()
380
+ if encoder_out["encoder_padding_mask"] is not None
381
+ else None
382
+ )
383
+ features[id.item()] = (feat[i], padding)
384
+ continue
385
+ hypos = task.inference_step(generator, models, sample, prefix_tokens)
386
+ num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos)
387
+ gen_timer.stop(num_generated_tokens)
388
+
389
+ for i, sample_id in enumerate(sample["id"].tolist()):
390
+ speaker = None
391
+ # id = task.dataset(args.gen_subset).ids[int(sample_id)]
392
+ id = sample_id
393
+ toks = (
394
+ sample["target"][i, :]
395
+ if "target_label" not in sample
396
+ else sample["target_label"][i, :]
397
+ )
398
+ target_tokens = utils.strip_pad(toks, tgt_dict.pad()).int().cpu()
399
+ # Process top predictions
400
+ errs, length = process_predictions(
401
+ args,
402
+ hypos[i],
403
+ None,
404
+ tgt_dict,
405
+ target_tokens,
406
+ res_files,
407
+ speaker,
408
+ id,
409
+ )
410
+ errs_t += errs
411
+ lengths_t += length
412
+
413
+ wps_meter.update(num_generated_tokens)
414
+ t.log({"wps": round(wps_meter.avg)})
415
+ num_sentences += (
416
+ sample["nsentences"] if "nsentences" in sample else sample["id"].numel()
417
+ )
418
+
419
+ wer = None
420
+ if args.dump_emissions:
421
+ emm_arr = []
422
+ for i in range(len(emissions)):
423
+ emm_arr.append(emissions[i])
424
+ np.save(args.dump_emissions, emm_arr)
425
+ logger.info(f"saved {len(emissions)} emissions to {args.dump_emissions}")
426
+ elif args.dump_features:
427
+ feat_arr = []
428
+ for i in range(len(features)):
429
+ feat_arr.append(features[i])
430
+ np.save(args.dump_features, feat_arr)
431
+ logger.info(f"saved {len(features)} emissions to {args.dump_features}")
432
+ else:
433
+ if lengths_t > 0:
434
+ wer = errs_t * 100.0 / lengths_t
435
+ logger.info(f"WER: {wer}")
436
+
437
+ logger.info(
438
+ "| Processed {} sentences ({} tokens) in {:.1f}s ({:.2f}"
439
+ "sentences/s, {:.2f} tokens/s)".format(
440
+ num_sentences,
441
+ gen_timer.n,
442
+ gen_timer.sum,
443
+ num_sentences / gen_timer.sum,
444
+ 1.0 / gen_timer.avg,
445
+ )
446
+ )
447
+ logger.info("| Generate {} with beam={}".format(args.gen_subset, args.beam))
448
+ return task, wer
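The WER returned above is just accumulated edit errors over total reference length, scaled to a percentage. A minimal standalone sketch of that accumulation (the helper name `corpus_wer` is hypothetical, not part of this file):

```python
def corpus_wer(per_utt):
    """Accumulate (errors, ref_word_count) pairs the way the loop above
    accumulates errs_t and lengths_t, then scale to a percentage."""
    errs_t = sum(errs for errs, _ in per_utt)
    lengths_t = sum(length for _, length in per_utt)
    return errs_t * 100.0 / lengths_t if lengths_t > 0 else None

# two utterances: 2 errors over 10 words, 1 error over 10 words -> 15% WER
print(corpus_wer([(2, 10), (1, 10)]))
```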
+
+
+def make_parser():
+    parser = options.get_generation_parser()
+    parser = add_asr_eval_argument(parser)
+    return parser
+
+
+def cli_main():
+    parser = make_parser()
+    args = options.parse_args_and_arch(parser)
+    main(args)
+
+
+if __name__ == "__main__":
+    cli_main()
fairseq-0.10.2/examples/speech_recognition/utils/wer_utils.py ADDED
@@ -0,0 +1,381 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from __future__ import absolute_import, division, print_function, unicode_literals
+
+import re
+from collections import deque
+from enum import Enum
+
+import numpy as np
+
+
+"""
+Utility modules for computation of Word Error Rate,
+alignments, as well as more granular metrics like
+deletion, insertion and substitution counts.
+"""
+
+
+class Code(Enum):
+    match = 1
+    substitution = 2
+    insertion = 3
+    deletion = 4
+
+
+class Token(object):
+    def __init__(self, lbl="", st=np.nan, en=np.nan):
+        if np.isnan(st):
+            self.label, self.start, self.end = "", 0.0, 0.0
+        else:
+            self.label, self.start, self.end = lbl, st, en
+
+
+class AlignmentResult(object):
+    def __init__(self, refs, hyps, codes, score):
+        self.refs = refs  # std::deque<int>
+        self.hyps = hyps  # std::deque<int>
+        self.codes = codes  # std::deque<Code>
+        self.score = score  # float
+
+
+def coordinate_to_offset(row, col, ncols):
+    return int(row * ncols + col)
+
+
+def offset_to_row(offset, ncols):
+    return int(offset / ncols)
+
+
+def offset_to_col(offset, ncols):
+    return int(offset % ncols)
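These helpers flatten a (row, col) coordinate into a single offset and back, mirroring how the backtrace matrix is indexed. A quick round-trip check (a standalone copy of the same arithmetic):

```python
def coordinate_to_offset(row, col, ncols):
    return row * ncols + col

def offset_to_row(offset, ncols):
    return offset // ncols

def offset_to_col(offset, ncols):
    return offset % ncols

# round-trip: (3, 2) in a 5-column matrix is offset 17
off = coordinate_to_offset(3, 2, 5)
assert off == 17
assert (offset_to_row(off, 5), offset_to_col(off, 5)) == (3, 2)
```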
+
+
+def trimWhitespace(str):
+    return re.sub(" +", " ", re.sub(" *$", "", re.sub("^ *", "", str)))
+
+
+def str2toks(str):
+    pieces = trimWhitespace(str).split(" ")
+    toks = []
+    for p in pieces:
+        toks.append(Token(p, 0.0, 0.0))
+    return toks
+
+
+class EditDistance(object):
+    def __init__(self, time_mediated):
+        self.time_mediated_ = time_mediated
+        self.scores_ = np.nan  # Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>
+        self.backtraces_ = (
+            np.nan
+        )  # Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic> backtraces_;
+        self.confusion_pairs_ = {}
+
+    def cost(self, ref, hyp, code):
+        if self.time_mediated_:
+            if code == Code.match:
+                return abs(ref.start - hyp.start) + abs(ref.end - hyp.end)
+            elif code == Code.insertion:
+                return hyp.end - hyp.start
+            elif code == Code.deletion:
+                return ref.end - ref.start
+            else:  # substitution
+                return abs(ref.start - hyp.start) + abs(ref.end - hyp.end) + 0.1
+        else:
+            if code == Code.match:
+                return 0
+            elif code == Code.insertion or code == Code.deletion:
+                return 3
+            else:  # substitution
+                return 4
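With time mediation off, `cost()` reduces to classic weighted Levenshtein costs: 0 for a match, 3 for an insertion or deletion, 4 for a substitution (so one substitution is cheaper than a deletion plus an insertion). A self-contained DP sketch using the same weights:

```python
def edit_cost(ref, hyp, ins=3, dele=3, sub=4):
    # standard edit-distance DP over word lists with the weights above
    d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(1, len(ref) + 1):
        d[i][0] = i * dele
    for j in range(1, len(hyp) + 1):
        d[0][j] = j * ins
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            diag = d[i - 1][j - 1] + (0 if ref[i - 1] == hyp[j - 1] else sub)
            d[i][j] = min(diag, d[i][j - 1] + ins, d[i - 1][j] + dele)
    return d[-1][-1]

# one substitution (4) beats a deletion plus an insertion (6)
assert edit_cost("a b c".split(), "a x c".split()) == 4
assert edit_cost(["a"], ["a", "b"]) == 3
```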
+
+    def get_result(self, refs, hyps):
+        res = AlignmentResult(refs=deque(), hyps=deque(), codes=deque(), score=np.nan)
+
+        num_rows, num_cols = self.scores_.shape
+        res.score = self.scores_[num_rows - 1, num_cols - 1]
+
+        curr_offset = coordinate_to_offset(num_rows - 1, num_cols - 1, num_cols)
+
+        while curr_offset != 0:
+            curr_row = offset_to_row(curr_offset, num_cols)
+            curr_col = offset_to_col(curr_offset, num_cols)
+
+            prev_offset = self.backtraces_[curr_row, curr_col]
+
+            prev_row = offset_to_row(prev_offset, num_cols)
+            prev_col = offset_to_col(prev_offset, num_cols)
+
+            res.refs.appendleft(curr_row - 1)  # Note: this was .push_front() in C++
+            res.hyps.appendleft(curr_col - 1)
+            if curr_row - 1 == prev_row and curr_col == prev_col:
+                res.codes.appendleft(Code.deletion)
+            elif curr_row == prev_row and curr_col - 1 == prev_col:
+                res.codes.appendleft(Code.insertion)
+            else:
+                # assert(curr_row - 1 == prev_row and curr_col - 1 == prev_col)
+                ref_str = refs[res.refs[0]].label
+                hyp_str = hyps[res.hyps[0]].label
+
+                if ref_str == hyp_str:
+                    res.codes.appendleft(Code.match)
+                else:
+                    res.codes.appendleft(Code.substitution)
+
+                    confusion_pair = "%s -> %s" % (ref_str, hyp_str)
+                    if confusion_pair not in self.confusion_pairs_:
+                        self.confusion_pairs_[confusion_pair] = 1
+                    else:
+                        self.confusion_pairs_[confusion_pair] += 1
+
+            curr_offset = prev_offset
+
+        return res
+
+    def align(self, refs, hyps):
+        if len(refs) == 0 and len(hyps) == 0:
+            return np.nan
+
+        # NOTE: we're not resetting the values in these matrices because every value
+        # will be overridden in the loop below. If this assumption doesn't hold,
+        # be sure to set all entries in self.scores_ and self.backtraces_ to 0.
+        self.scores_ = np.zeros((len(refs) + 1, len(hyps) + 1))
+        self.backtraces_ = np.zeros((len(refs) + 1, len(hyps) + 1))
+
+        num_rows, num_cols = self.scores_.shape
+
+        for i in range(num_rows):
+            for j in range(num_cols):
+                if i == 0 and j == 0:
+                    self.scores_[i, j] = 0.0
+                    self.backtraces_[i, j] = 0
+                    continue
+
+                if i == 0:
+                    self.scores_[i, j] = self.scores_[i, j - 1] + self.cost(
+                        None, hyps[j - 1], Code.insertion
+                    )
+                    self.backtraces_[i, j] = coordinate_to_offset(i, j - 1, num_cols)
+                    continue
+
+                if j == 0:
+                    self.scores_[i, j] = self.scores_[i - 1, j] + self.cost(
+                        refs[i - 1], None, Code.deletion
+                    )
+                    self.backtraces_[i, j] = coordinate_to_offset(i - 1, j, num_cols)
+                    continue
+
+                # Below here both i and j are greater than 0
+                ref = refs[i - 1]
+                hyp = hyps[j - 1]
+                best_score = self.scores_[i - 1, j - 1] + (
+                    self.cost(ref, hyp, Code.match)
+                    if (ref.label == hyp.label)
+                    else self.cost(ref, hyp, Code.substitution)
+                )
+
+                prev_row = i - 1
+                prev_col = j - 1
+                ins = self.scores_[i, j - 1] + self.cost(None, hyp, Code.insertion)
+                if ins < best_score:
+                    best_score = ins
+                    prev_row = i
+                    prev_col = j - 1
+
+                delt = self.scores_[i - 1, j] + self.cost(ref, None, Code.deletion)
+                if delt < best_score:
+                    best_score = delt
+                    prev_row = i - 1
+                    prev_col = j
+
+                self.scores_[i, j] = best_score
+                self.backtraces_[i, j] = coordinate_to_offset(
+                    prev_row, prev_col, num_cols
+                )
+
+        return self.get_result(refs, hyps)
+
+
+class WERTransformer(object):
+    def __init__(self, hyp_str, ref_str, verbose=True):
+        self.ed_ = EditDistance(False)
+        self.id2oracle_errs_ = {}
+        self.utts_ = 0
+        self.words_ = 0
+        self.insertions_ = 0
+        self.deletions_ = 0
+        self.substitutions_ = 0
+
+        self.process(["dummy_str", hyp_str, ref_str])
+
+        if verbose:
+            print("'%s' vs '%s'" % (hyp_str, ref_str))
+            self.report_result()
+
+    def process(self, input):  # std::vector<std::string>&& input
+        if len(input) < 3:
+            print(
+                "Input must be of the form <id> ... <hypo> <ref> , got ",
+                len(input),
+                " inputs:",
+            )
+            return None
+
+        # Align
+        # std::vector<Token> hyps;
+        # std::vector<Token> refs;
+
+        hyps = str2toks(input[-2])
+        refs = str2toks(input[-1])
+
+        alignment = self.ed_.align(refs, hyps)
+        if alignment is None:
+            print("Alignment is null")
+            return np.nan
+
+        # Tally errors
+        ins = 0
+        dels = 0
+        subs = 0
+        for code in alignment.codes:
+            if code == Code.substitution:
+                subs += 1
+            elif code == Code.insertion:
+                ins += 1
+            elif code == Code.deletion:
+                dels += 1
+
+        # Output
+        row = input
+        row.append(str(len(refs)))
+        row.append(str(ins))
+        row.append(str(dels))
+        row.append(str(subs))
+        # print(row)
+
+        # Accumulate
+        kIdIndex = 0
+        kNBestSep = "/"
+
+        pieces = input[kIdIndex].split(kNBestSep)
+
+        if len(pieces) == 0:
+            print(
+                "Error splitting ",
+                input[kIdIndex],
+                " on '",
+                kNBestSep,
+                "', got empty list",
+            )
+            return np.nan
+
+        id = pieces[0]
+        if id not in self.id2oracle_errs_:
+            self.utts_ += 1
+            self.words_ += len(refs)
+            self.insertions_ += ins
+            self.deletions_ += dels
+            self.substitutions_ += subs
+            self.id2oracle_errs_[id] = [ins, dels, subs]
+        else:
+            curr_err = ins + dels + subs
+            prev_err = np.sum(self.id2oracle_errs_[id])
+            if curr_err < prev_err:
+                self.id2oracle_errs_[id] = [ins, dels, subs]
+
+        return 0
+
+    def report_result(self):
+        # print("---------- Summary ---------------")
+        if self.words_ == 0:
+            print("No words counted")
+            return
+
+        # 1-best
+        best_wer = (
+            100.0
+            * (self.insertions_ + self.deletions_ + self.substitutions_)
+            / self.words_
+        )
+
+        print(
+            "\tWER = %0.2f%% (%i utts, %i words, %0.2f%% ins, "
+            "%0.2f%% dels, %0.2f%% subs)"
+            % (
+                best_wer,
+                self.utts_,
+                self.words_,
+                100.0 * self.insertions_ / self.words_,
+                100.0 * self.deletions_ / self.words_,
+                100.0 * self.substitutions_ / self.words_,
+            )
+        )
+
+    def wer(self):
+        if self.words_ == 0:
+            wer = np.nan
+        else:
+            wer = (
+                100.0
+                * (self.insertions_ + self.deletions_ + self.substitutions_)
+                / self.words_
+            )
+        return wer
+
+    def stats(self):
+        if self.words_ == 0:
+            stats = {}
+        else:
+            wer = (
+                100.0
+                * (self.insertions_ + self.deletions_ + self.substitutions_)
+                / self.words_
+            )
+            stats = dict(
+                {
+                    "wer": wer,
+                    "utts": self.utts_,
+                    "numwords": self.words_,
+                    "ins": self.insertions_,
+                    "dels": self.deletions_,
+                    "subs": self.substitutions_,
+                    "confusion_pairs": self.ed_.confusion_pairs_,
+                }
+            )
+        return stats
+
+
+def calc_wer(hyp_str, ref_str):
+    t = WERTransformer(hyp_str, ref_str, verbose=0)
+    return t.wer()
+
+
+def calc_wer_stats(hyp_str, ref_str):
+    t = WERTransformer(hyp_str, ref_str, verbose=0)
+    return t.stats()
+
+
+def get_wer_alignment_codes(hyp_str, ref_str):
+    """
+    INPUT: hypothesis string, reference string
+    OUTPUT: list of alignment codes (intermediate results from WER computation)
+    """
+    t = WERTransformer(hyp_str, ref_str, verbose=0)
+    return t.ed_.align(str2toks(ref_str), str2toks(hyp_str)).codes
+
+
+def merge_counts(x, y):
+    # Merge two dicts whose values are counts.
+    # This can be used, for example, to merge confusion pair counts:
+    #   conf_pairs = merge_counts(conf_pairs, stats['confusion_pairs'])
+    for k, v in y.items():
+        if k not in x:
+            x[k] = 0
+        x[k] += v
+    return x
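merge_counts is a plain counter merge; it behaves like adding two collections.Counter objects, except that it mutates and returns x. A standalone sketch of the same behavior:

```python
def merge_counts(x, y):
    # add y's counts into x in place, creating missing keys as needed
    for k, v in y.items():
        x[k] = x.get(k, 0) + v
    return x

conf_pairs = merge_counts({"a -> b": 2}, {"a -> b": 1, "c -> d": 5})
assert conf_pairs == {"a -> b": 3, "c -> d": 5}
```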
fairseq-0.10.2/examples/speech_recognition/w2l_decoder.py ADDED
@@ -0,0 +1,435 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Wav2letter decoders.
+"""
+
+import gc
+import itertools as it
+import os.path as osp
+import warnings
+from collections import deque, namedtuple
+
+import numpy as np
+import torch
+from examples.speech_recognition.data.replabels import unpack_replabels
+from fairseq import tasks
+from fairseq.utils import apply_to_sample
+
+
+try:
+    from wav2letter.common import create_word_dict, load_words
+    from wav2letter.criterion import CpuViterbiPath, get_data_ptr_as_bytes
+    from wav2letter.decoder import (
+        CriterionType,
+        DecoderOptions,
+        KenLM,
+        LM,
+        LMState,
+        SmearingMode,
+        Trie,
+        LexiconDecoder,
+        LexiconFreeDecoder,
+    )
+except ImportError:
+    warnings.warn(
+        "wav2letter python bindings are required to use this functionality. "
+        "Please install from https://github.com/facebookresearch/wav2letter/wiki/Python-bindings"
+    )
+    LM = object
+    LMState = object
+
+
+class W2lDecoder(object):
+    def __init__(self, args, tgt_dict):
+        self.tgt_dict = tgt_dict
+        self.vocab_size = len(tgt_dict)
+        self.nbest = args.nbest
+
+        # criterion-specific init
+        if args.criterion == "ctc":
+            self.criterion_type = CriterionType.CTC
+            self.blank = (
+                tgt_dict.index("<ctc_blank>")
+                if "<ctc_blank>" in tgt_dict.indices
+                else tgt_dict.bos()
+            )
+            self.asg_transitions = None
+        elif args.criterion == "asg_loss":
+            self.criterion_type = CriterionType.ASG
+            self.blank = -1
+            self.asg_transitions = args.asg_transitions
+            self.max_replabel = args.max_replabel
+            assert len(self.asg_transitions) == self.vocab_size ** 2
+        else:
+            raise RuntimeError(f"unknown criterion: {args.criterion}")
+
+    def generate(self, models, sample, **unused):
+        """Generate a batch of inferences."""
+        # model.forward normally channels prev_output_tokens into the decoder
+        # separately, but SequenceGenerator directly calls model.encoder
+        encoder_input = {
+            k: v for k, v in sample["net_input"].items() if k != "prev_output_tokens"
+        }
+        emissions = self.get_emissions(models, encoder_input)
+        return self.decode(emissions)
+
+    def get_emissions(self, models, encoder_input):
+        """Run encoder and normalize emissions"""
+        # encoder_out = models[0].encoder(**encoder_input)
+        encoder_out = models[0](**encoder_input)
+        if self.criterion_type == CriterionType.CTC:
+            emissions = models[0].get_normalized_probs(encoder_out, log_probs=True)
+        elif self.criterion_type == CriterionType.ASG:
+            emissions = encoder_out["encoder_out"]
+        return emissions.transpose(0, 1).float().cpu().contiguous()
+
+    def get_tokens(self, idxs):
+        """Normalize tokens by handling CTC blank, ASG replabels, etc."""
+        idxs = (g[0] for g in it.groupby(idxs))
+        if self.criterion_type == CriterionType.CTC:
+            idxs = filter(lambda x: x != self.blank, idxs)
+        elif self.criterion_type == CriterionType.ASG:
+            idxs = filter(lambda x: x >= 0, idxs)
+            idxs = unpack_replabels(list(idxs), self.tgt_dict, self.max_replabel)
+        return torch.LongTensor(list(idxs))
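For CTC, get_tokens first collapses consecutive repeats with itertools.groupby and then drops the blank symbol. The same logic on plain lists (blank index 0 is an assumption for illustration):

```python
from itertools import groupby

def ctc_collapse(idxs, blank=0):
    # collapse runs of repeated indices, then strip the blank symbol
    return [k for k, _ in groupby(idxs) if k != blank]

# "- 3 3 - 4 4 4 -" decodes to [3, 4]
assert ctc_collapse([0, 3, 3, 0, 4, 4, 4, 0]) == [3, 4]
assert ctc_collapse([5, 5, 5]) == [5]
```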
100
+
101
+
102
+ class W2lViterbiDecoder(W2lDecoder):
103
+ def __init__(self, args, tgt_dict):
104
+ super().__init__(args, tgt_dict)
105
+
106
+ def decode(self, emissions):
107
+ B, T, N = emissions.size()
108
+ hypos = []
109
+ if self.asg_transitions is None:
110
+ transitions = torch.FloatTensor(N, N).zero_()
111
+ else:
112
+ transitions = torch.FloatTensor(self.asg_transitions).view(N, N)
113
+ viterbi_path = torch.IntTensor(B, T)
114
+ workspace = torch.ByteTensor(CpuViterbiPath.get_workspace_size(B, T, N))
115
+ CpuViterbiPath.compute(
116
+ B,
117
+ T,
118
+ N,
119
+ get_data_ptr_as_bytes(emissions),
120
+ get_data_ptr_as_bytes(transitions),
121
+ get_data_ptr_as_bytes(viterbi_path),
122
+ get_data_ptr_as_bytes(workspace),
123
+ )
124
+ return [
125
+ [{"tokens": self.get_tokens(viterbi_path[b].tolist()), "score": 0}]
126
+ for b in range(B)
127
+ ]
128
+
129
+
130
+ class W2lKenLMDecoder(W2lDecoder):
131
+ def __init__(self, args, tgt_dict):
132
+ super().__init__(args, tgt_dict)
133
+
134
+ self.silence = (
135
+ tgt_dict.index("<ctc_blank>")
136
+ if "<ctc_blank>" in tgt_dict.indices
137
+ else tgt_dict.bos()
138
+ )
139
+ self.lexicon = load_words(args.lexicon)
140
+ self.word_dict = create_word_dict(self.lexicon)
141
+ self.unk_word = self.word_dict.get_index("<unk>")
142
+
143
+ self.lm = KenLM(args.kenlm_model, self.word_dict)
144
+ self.trie = Trie(self.vocab_size, self.silence)
145
+
146
+ start_state = self.lm.start(False)
147
+ for i, (word, spellings) in enumerate(self.lexicon.items()):
148
+ word_idx = self.word_dict.get_index(word)
149
+ _, score = self.lm.score(start_state, word_idx)
150
+ for spelling in spellings:
151
+ spelling_idxs = [tgt_dict.index(token) for token in spelling]
152
+ assert (
153
+ tgt_dict.unk() not in spelling_idxs
154
+ ), f"{spelling} {spelling_idxs}"
155
+ self.trie.insert(spelling_idxs, word_idx, score)
156
+ self.trie.smear(SmearingMode.MAX)
157
+
158
+ self.decoder_opts = DecoderOptions(
159
+ args.beam,
160
+ int(getattr(args, "beam_size_token", len(tgt_dict))),
161
+ args.beam_threshold,
162
+ args.lm_weight,
163
+ args.word_score,
164
+ args.unk_weight,
165
+ args.sil_weight,
166
+ 0,
167
+ False,
168
+ self.criterion_type,
169
+ )
170
+
171
+ if self.asg_transitions is None:
172
+ N = 768
173
+ # self.asg_transitions = torch.FloatTensor(N, N).zero_()
174
+ self.asg_transitions = []
175
+
176
+ self.decoder = LexiconDecoder(
177
+ self.decoder_opts,
178
+ self.trie,
179
+ self.lm,
180
+ self.silence,
181
+ self.blank,
182
+ self.unk_word,
183
+ self.asg_transitions,
184
+ False,
185
+ )
186
+
187
+ def decode(self, emissions):
188
+ B, T, N = emissions.size()
189
+ hypos = []
190
+ for b in range(B):
191
+ emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
192
+ results = self.decoder.decode(emissions_ptr, T, N)
193
+
194
+ nbest_results = results[: self.nbest]
195
+ hypos.append(
196
+ [
197
+ {
198
+ "tokens": self.get_tokens(result.tokens),
199
+ "score": result.score,
200
+ "words": [
201
+ self.word_dict.get_entry(x) for x in result.words if x >= 0
202
+ ],
203
+ }
204
+ for result in nbest_results
205
+ ]
206
+ )
207
+ return hypos
208
+
209
+
210
+ FairseqLMState = namedtuple("FairseqLMState", ["prefix", "incremental_state", "probs"])
211
+
212
+
213
+ class FairseqLM(LM):
214
+ def __init__(self, dictionary, model):
215
+ LM.__init__(self)
216
+ self.dictionary = dictionary
217
+ self.model = model
218
+ self.unk = self.dictionary.unk()
219
+
220
+ self.save_incremental = False # this currently does not work properly
221
+ self.max_cache = 20_000
222
+
223
+ model.cuda()
224
+ model.eval()
225
+ model.make_generation_fast_()
226
+
227
+ self.states = {}
228
+ self.stateq = deque()
229
+
230
+ def start(self, start_with_nothing):
231
+ state = LMState()
232
+ prefix = torch.LongTensor([[self.dictionary.eos()]])
233
+ incremental_state = {} if self.save_incremental else None
234
+ with torch.no_grad():
235
+ res = self.model(prefix.cuda(), incremental_state=incremental_state)
236
+ probs = self.model.get_normalized_probs(res, log_probs=True, sample=None)
237
+
238
+ if incremental_state is not None:
239
+ incremental_state = apply_to_sample(lambda x: x.cpu(), incremental_state)
240
+ self.states[state] = FairseqLMState(
241
+ prefix.numpy(), incremental_state, probs[0, -1].cpu().numpy()
242
+ )
243
+ self.stateq.append(state)
244
+
245
+ return state
246
+
247
+ def score(self, state: LMState, token_index: int, no_cache: bool = False):
248
+ """
249
+ Evaluate language model based on the current lm state and new word
250
+ Parameters:
251
+ -----------
252
+ state: current lm state
253
+ token_index: index of the word
254
+ (can be lexicon index then you should store inside LM the
255
+ mapping between indices of lexicon and lm, or lm index of a word)
256
+
257
+ Returns:
258
+ --------
259
+ (LMState, float): pair of (new state, score for the current word)
260
+ """
261
+ curr_state = self.states[state]
262
+
263
+ def trim_cache(targ_size):
264
+ while len(self.stateq) > targ_size:
265
+ rem_k = self.stateq.popleft()
266
+ rem_st = self.states[rem_k]
267
+ rem_st = FairseqLMState(rem_st.prefix, None, None)
268
+ self.states[rem_k] = rem_st
269
+
270
+ if curr_state.probs is None:
271
+ new_incremental_state = (
272
+ curr_state.incremental_state.copy()
273
+ if curr_state.incremental_state is not None
274
+ else None
275
+ )
276
+ with torch.no_grad():
277
+ if new_incremental_state is not None:
278
+ new_incremental_state = apply_to_sample(
279
+ lambda x: x.cuda(), new_incremental_state
280
+ )
281
+ elif self.save_incremental:
282
+ new_incremental_state = {}
283
+
284
+ res = self.model(
285
+ torch.from_numpy(curr_state.prefix).cuda(),
286
+ incremental_state=new_incremental_state,
287
+ )
288
+ probs = self.model.get_normalized_probs(
289
+ res, log_probs=True, sample=None
290
+ )
291
+
292
+ if new_incremental_state is not None:
293
+ new_incremental_state = apply_to_sample(
294
+ lambda x: x.cpu(), new_incremental_state
295
+ )
296
+
297
+ curr_state = FairseqLMState(
298
+ curr_state.prefix, new_incremental_state, probs[0, -1].cpu().numpy()
299
+ )
300
+
301
+ if not no_cache:
302
+ self.states[state] = curr_state
303
+ self.stateq.append(state)
304
+
305
+ score = curr_state.probs[token_index].item()
306
+
307
+ trim_cache(self.max_cache)
308
+
309
+ outstate = state.child(token_index)
310
+ if outstate not in self.states and not no_cache:
311
+ prefix = np.concatenate(
312
+ [curr_state.prefix, torch.LongTensor([[token_index]])], -1
313
+ )
314
+ incr_state = curr_state.incremental_state
315
+
316
+ self.states[outstate] = FairseqLMState(prefix, incr_state, None)
317
+
318
+ if token_index == self.unk:
319
+ score = float("-inf")
320
+
321
+ return outstate, score
322
+
323
+ def finish(self, state: LMState):
324
+ """
325
+ Evaluate eos for language model based on the current lm state
326
+
327
+ Returns:
328
+ --------
329
+ (LMState, float): pair of (new state, score for the current word)
330
+ """
331
+ return self.score(state, self.dictionary.eos())
332
+
333
+ def empty_cache(self):
334
+ self.states = {}
335
+ self.stateq = deque()
336
+ gc.collect()
337
+
338
+
339
+ class W2lFairseqLMDecoder(W2lDecoder):
340
+     def __init__(self, args, tgt_dict):
+         super().__init__(args, tgt_dict)
+
+         self.silence = tgt_dict.bos()
+
+         self.unit_lm = getattr(args, "unit_lm", False)
+
+         self.lexicon = load_words(args.lexicon) if args.lexicon else None
+         self.idx_to_wrd = {}
+
+         checkpoint = torch.load(args.kenlm_model, map_location="cpu")
+         lm_args = checkpoint["args"]
+         lm_args.data = osp.dirname(args.kenlm_model)
+         print(lm_args)
+         task = tasks.setup_task(lm_args)
+         model = task.build_model(lm_args)
+         model.load_state_dict(checkpoint["model"], strict=False)
+
+         self.trie = Trie(self.vocab_size, self.silence)
+
+         self.word_dict = task.dictionary
+         self.unk_word = self.word_dict.unk()
+         self.lm = FairseqLM(self.word_dict, model)
+
+         self.decoder_opts = DecoderOptions(
+             args.beam,
+             int(getattr(args, "beam_size_token", len(tgt_dict))),
+             args.beam_threshold,
+             args.lm_weight,
+             args.word_score,
+             args.unk_weight,
+             args.sil_weight,
+             0,
+             False,
+             self.criterion_type,
+         )
+
+         if self.lexicon:
+             start_state = self.lm.start(False)
+             for i, (word, spellings) in enumerate(self.lexicon.items()):
+                 if self.unit_lm:
+                     word_idx = i
+                     self.idx_to_wrd[i] = word
+                     score = 0
+                 else:
+                     word_idx = self.word_dict.index(word)
+                     _, score = self.lm.score(start_state, word_idx, no_cache=True)
+
+                 for spelling in spellings:
+                     spelling_idxs = [tgt_dict.index(token) for token in spelling]
+                     assert (
+                         tgt_dict.unk() not in spelling_idxs
+                     ), f"{spelling} {spelling_idxs}"
+                     self.trie.insert(spelling_idxs, word_idx, score)
+             self.trie.smear(SmearingMode.MAX)
+
+             self.decoder = LexiconDecoder(
+                 self.decoder_opts,
+                 self.trie,
+                 self.lm,
+                 self.silence,
+                 self.blank,
+                 self.unk_word,
+                 [],
+                 self.unit_lm,
+             )
+         else:
+             self.decoder = LexiconFreeDecoder(
+                 self.decoder_opts, self.lm, self.silence, self.blank, []
+             )
+
+     def decode(self, emissions):
+         B, T, N = emissions.size()
+         hypos = []
+
+         def idx_to_word(idx):
+             if self.unit_lm:
+                 return self.idx_to_wrd[idx]
+             else:
+                 return self.word_dict[idx]
+
+         def make_hypo(result):
+             hypo = {"tokens": self.get_tokens(result.tokens), "score": result.score}
+             if self.lexicon:
+                 hypo["words"] = [idx_to_word(x) for x in result.words if x >= 0]
+             return hypo
+
+         for b in range(B):
+             emissions_ptr = emissions.data_ptr() + 4 * b * emissions.stride(0)
+             results = self.decoder.decode(emissions_ptr, T, N)
+
+             nbest_results = results[: self.nbest]
+             hypos.append([make_hypo(result) for result in nbest_results])
+             self.lm.empty_cache()
+
+         return hypos
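The byte-offset arithmetic in `decode` above (`4 * b * emissions.stride(0)`) hard-codes 4 bytes per element, i.e. it assumes float32 emissions. A minimal sketch of the same pointer arithmetic using NumPy (the array and variable names here are illustrative only):

```python
import numpy as np

# Batch element b of a contiguous (B, T, N) float32 buffer starts at
# base_ptr + itemsize * b * stride(0), where stride(0) is counted in elements.
B, T, N = 2, 3, 5
emissions = np.zeros((B, T, N), dtype=np.float32)
stride0 = emissions.strides[0] // emissions.itemsize  # elements, like torch .stride(0)
b = 1
byte_offset = emissions.itemsize * b * stride0  # the decoder's "4 * b * stride(0)"
assert emissions.itemsize == 4  # float32: the hard-coded 4
assert byte_offset == emissions[b:].ctypes.data - emissions.ctypes.data
```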
fairseq-0.10.2/examples/speech_to_text/data_utils.py ADDED
@@ -0,0 +1,262 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import csv
+ import os
+ import os.path as op
+ import zipfile
+ from functools import reduce
+ from glob import glob
+ from multiprocessing import cpu_count
+ from typing import Any, Dict, List
+
+ import numpy as np
+ import sentencepiece as sp
+ from fairseq.data.audio.audio_utils import _get_kaldi_fbank, _get_torchaudio_fbank
+ from fairseq.data.audio.feature_transforms.utterance_cmvn import UtteranceCMVN
+ from tqdm import tqdm
+
+
+ UNK_TOKEN, UNK_TOKEN_ID = "<unk>", 3
+ BOS_TOKEN, BOS_TOKEN_ID = "<s>", 0
+ EOS_TOKEN, EOS_TOKEN_ID = "</s>", 2
+ PAD_TOKEN, PAD_TOKEN_ID = "<pad>", 1
+
+
+ def gen_vocab(
+     input_path: str,
+     output_path_prefix: str,
+     model_type="bpe",
+     vocab_size=1000,
+ ):
+     # Train SentencePiece Model
+     arguments = [
+         f"--input={input_path}",
+         f"--model_prefix={output_path_prefix}",
+         f"--model_type={model_type}",
+         f"--vocab_size={vocab_size}",
+         "--character_coverage=1.0",
+         f"--num_threads={cpu_count()}",
+         f"--unk_id={UNK_TOKEN_ID}",
+         f"--bos_id={BOS_TOKEN_ID}",
+         f"--eos_id={EOS_TOKEN_ID}",
+         f"--pad_id={PAD_TOKEN_ID}",
+     ]
+     sp.SentencePieceTrainer.Train(" ".join(arguments))
+     # Export fairseq dictionary
+     spm = sp.SentencePieceProcessor()
+     spm.Load(output_path_prefix + ".model")
+     vocab = {i: spm.IdToPiece(i) for i in range(spm.GetPieceSize())}
+     assert (
+         vocab.get(UNK_TOKEN_ID) == UNK_TOKEN
+         and vocab.get(PAD_TOKEN_ID) == PAD_TOKEN
+         and vocab.get(BOS_TOKEN_ID) == BOS_TOKEN
+         and vocab.get(EOS_TOKEN_ID) == EOS_TOKEN
+     )
+     vocab = {
+         i: s
+         for i, s in vocab.items()
+         if s not in {UNK_TOKEN, BOS_TOKEN, EOS_TOKEN, PAD_TOKEN}
+     }
+     with open(output_path_prefix + ".txt", "w") as f_out:
+         for _, s in sorted(vocab.items(), key=lambda x: x[0]):
+             f_out.write(f"{s} 1\n")
+
+
+ def extract_fbank_features(
+     waveform,
+     sample_rate,
+     output_path=None,
+     n_mel_bins=80,
+     apply_utterance_cmvn=True,
+     overwrite=False,
+ ):
+     if output_path is not None and op.exists(output_path) and not overwrite:
+         return
+
+     _waveform = waveform * (2 ** 15)  # Kaldi compliance: 16-bit signed integers
+     _waveform = _waveform.squeeze().numpy()
+
+     features = _get_kaldi_fbank(_waveform, sample_rate, n_mel_bins)
+     if features is None:
+         features = _get_torchaudio_fbank(_waveform, sample_rate, n_mel_bins)
+     if features is None:
+         raise ImportError(
+             "Please install pyKaldi or torchaudio to enable "
+             "online filterbank feature extraction"
+         )
+
+     if apply_utterance_cmvn:
+         cmvn = UtteranceCMVN(norm_means=True, norm_vars=True)
+         features = cmvn(features)
+     if output_path is not None:
+         np.save(output_path, features)
+     else:
+         return features
+
+
+ def create_zip(data_root, zip_path):
+     cwd = os.path.abspath(os.curdir)
+     os.chdir(data_root)
+     with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_STORED) as f:
+         for filename in tqdm(glob("*.npy")):
+             f.write(filename)
+     os.chdir(cwd)
+
+
+ def is_npy_data(data: bytes) -> bool:
+     return data[0] == 147 and data[1] == 78
+
+
+ def get_zip_manifest(zip_root, zip_filename):
+     zip_path = op.join(zip_root, zip_filename)
+     with zipfile.ZipFile(zip_path, mode="r") as f:
+         info = f.infolist()
+     manifest = {}
+     for i in tqdm(info):
+         utt_id = op.splitext(i.filename)[0]
+         offset, file_size = i.header_offset + 30 + len(i.filename), i.file_size
+         manifest[utt_id] = f"{zip_filename}:{offset}:{file_size}"
+         with open(zip_path, "rb") as f:
+             f.seek(offset)
+             data = f.read(file_size)
+             assert len(data) > 1 and is_npy_data(data)
+     return manifest
+
+
+ def gen_config_yaml(
+     data_root, spm_filename, yaml_filename="config.yaml", specaugment_policy="lb"
+ ):
+     assert specaugment_policy in {"lb", "ld"}
+     data_root = op.abspath(data_root)
+     writer = S2TDataConfigWriter(op.join(data_root, yaml_filename))
+     writer.set_audio_root(op.abspath(data_root))
+     writer.set_vocab_filename(spm_filename.replace(".model", ".txt"))
+     writer.set_input_channels(1)
+     writer.set_input_feat_per_channel(80)
+     if specaugment_policy == "lb":
+         writer.set_specaugment_lb_policy()
+     else:
+         writer.set_specaugment_ld_policy()
+     writer.set_bpe_tokenizer(
+         {
+             "bpe": "sentencepiece",
+             "sentencepiece_model": op.join(data_root, spm_filename),
+         }
+     )
+     writer.set_feature_transforms("_train", ["specaugment"])
+     writer.flush()
+
+
+ def save_df_to_tsv(dataframe, path):
+     dataframe.to_csv(
+         path,
+         sep="\t",
+         header=True,
+         index=False,
+         encoding="utf-8",
+         escapechar="\\",
+         quoting=csv.QUOTE_NONE,
+     )
+
+
+ def filter_manifest_df(
+     df, is_train_split=False, extra_filters=None, min_n_frames=5, max_n_frames=3000
+ ):
+     filters = {
+         "no speech": df["audio"] == "",
+         f"short speech (<{min_n_frames} frames)": df["n_frames"] < min_n_frames,
+         "empty sentence": df["tgt_text"] == "",
+     }
+     if is_train_split:
+         filters[f"long speech (>{max_n_frames} frames)"] = df["n_frames"] > max_n_frames
+     if extra_filters is not None:
+         filters.update(extra_filters)
+     invalid = reduce(lambda x, y: x | y, filters.values())
+     valid = ~invalid
+     print(
+         "| "
+         + ", ".join(f"{n}: {f.sum()}" for n, f in filters.items())
+         + f", total {invalid.sum()} filtered, {valid.sum()} remained."
+     )
+     return df[valid]
+
+
+ class S2TDataConfigWriter(object):
+     DEFAULT_VOCAB_FILENAME = "dict.txt"
+     DEFAULT_INPUT_FEAT_PER_CHANNEL = 80
+     DEFAULT_INPUT_CHANNELS = 1
+
+     def __init__(self, yaml_path):
+         try:
+             import yaml
+         except ImportError:
+             print("Please install PyYAML to load YAML files for S2T data config")
+         self.yaml = yaml
+         self.yaml_path = yaml_path
+         self.config = {}
+
+     def flush(self):
+         with open(self.yaml_path, "w") as f:
+             self.yaml.dump(self.config, f)
+
+     def set_audio_root(self, audio_root=""):
+         self.config["audio_root"] = audio_root
+
+     def set_vocab_filename(self, vocab_filename="dict.txt"):
+         self.config["vocab_filename"] = vocab_filename
+
+     def set_specaugment(
+         self,
+         time_wrap_w: int,
+         freq_mask_n: int,
+         freq_mask_f: int,
+         time_mask_n: int,
+         time_mask_t: int,
+         time_mask_p: float,
+     ):
+         self.config["specaugment"] = {
+             "time_wrap_W": time_wrap_w,
+             "freq_mask_N": freq_mask_n,
+             "freq_mask_F": freq_mask_f,
+             "time_mask_N": time_mask_n,
+             "time_mask_T": time_mask_t,
+             "time_mask_p": time_mask_p,
+         }
+
+     def set_specaugment_lb_policy(self):
+         self.set_specaugment(
+             time_wrap_w=0,
+             freq_mask_n=1,
+             freq_mask_f=27,
+             time_mask_n=1,
+             time_mask_t=100,
+             time_mask_p=1.0,
+         )
+
+     def set_specaugment_ld_policy(self):
+         self.set_specaugment(
+             time_wrap_w=0,
+             freq_mask_n=2,
+             freq_mask_f=27,
+             time_mask_n=2,
+             time_mask_t=100,
+             time_mask_p=1.0,
+         )
+
+     def set_input_channels(self, input_channels=1):
+         self.config["input_channels"] = input_channels
+
+     def set_input_feat_per_channel(self, input_feat_per_channel=80):
+         self.config["input_feat_per_channel"] = input_feat_per_channel
+
+     def set_bpe_tokenizer(self, bpe_tokenizer: Dict[str, Any]):
+         self.config["bpe_tokenizer"] = bpe_tokenizer
+
+     def set_feature_transforms(self, split, transforms: List[str]):
+         if "transforms" not in self.config:
+             self.config["transforms"] = {}
+         self.config["transforms"][split] = transforms
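`is_npy_data` above works because every NPY file begins with the magic string `\x93NUMPY`: byte `0x93` is 147 and `N` is 78, exactly the two bytes the helper checks. A quick self-contained sanity check:

```python
import io

import numpy as np

# Serialize a tiny array to an in-memory NPY file and inspect its magic bytes.
buf = io.BytesIO()
np.save(buf, np.arange(3, dtype=np.float32))
data = buf.getvalue()
assert data[:6] == b"\x93NUMPY"
assert data[0] == 147 and data[1] == 78  # the two bytes is_npy_data tests
```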
fairseq-0.10.2/examples/speech_to_text/prep_covost_data.py ADDED
@@ -0,0 +1,294 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import csv
+ import logging
+ import os
+ import os.path as op
+ import shutil
+ from tempfile import NamedTemporaryFile
+ from typing import Optional, Tuple
+
+ import pandas as pd
+ import torchaudio
+ from examples.speech_to_text.data_utils import (
+     create_zip,
+     extract_fbank_features,
+     filter_manifest_df,
+     gen_config_yaml,
+     gen_vocab,
+     get_zip_manifest,
+     save_df_to_tsv,
+ )
+ from torch import Tensor
+ from torch.utils.data import Dataset
+ from torchaudio.datasets.utils import download_url, extract_archive
+ from tqdm import tqdm
+
+
+ log = logging.getLogger(__name__)
+
+
+ MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
+
+
+ class CoVoST(Dataset):
+     """Create a Dataset for CoVoST (https://github.com/facebookresearch/covost).
+
+     Args:
+         root (str): root path to the dataset and generated manifests/features
+         split (str): dataset split, one of "train", "dev" or "test"
+         source_language (str): source (audio) language
+         target_language (str, optional): target (text) language,
+             None for no translation (default: None)
+         version (int, optional): CoVoST version. (default: 2)
+         download (bool, optional): Whether to download the dataset if it is not
+             found at root path. (default: ``False``).
+     """
+
+     CV_URL_TEMPLATE = (
+         "https://voice-prod-bundler-ee1969a6ce8178826482b88"
+         "e843c335139bd3fb4.s3.amazonaws.com/{ver}/{lang}.tar.gz"
+     )
+     COVOST_URL_TEMPLATE = (
+         "https://dl.fbaipublicfiles.com/covost/"
+         "covost_v2.{src_lang}_{tgt_lang}.tsv.tar.gz"
+     )
+
+     VERSIONS = {2}
+     SPLITS = ["train", "dev", "test"]
+
+     CV_VERSION_ID = {1: "cv-corpus-3", 2: "cv-corpus-4-2019-12-10"}
+
+     XX_EN_LANGUAGES = {
+         1: ["fr", "de", "nl", "ru", "es", "it", "tr", "fa", "sv-SE", "mn", "zh-CN"],
+         2: [
+             "fr",
+             "de",
+             "es",
+             "ca",
+             "it",
+             "ru",
+             "zh-CN",
+             "pt",
+             "fa",
+             "et",
+             "mn",
+             "nl",
+             "tr",
+             "ar",
+             "sv-SE",
+             "lv",
+             "sl",
+             "ta",
+             "ja",
+             "id",
+             "cy",
+         ],
+     }
+     EN_XX_LANGUAGES = {
+         1: [],
+         2: [
+             "de",
+             "tr",
+             "fa",
+             "sv-SE",
+             "mn",
+             "zh-CN",
+             "cy",
+             "ca",
+             "sl",
+             "et",
+             "id",
+             "ar",
+             "ta",
+             "lv",
+             "ja",
+         ],
+     }
+
+     def __init__(
+         self,
+         root: str,
+         split: str,
+         source_language: str,
+         target_language: Optional[str] = None,
+         version: int = 2,
+         download: bool = False,
+     ) -> None:
+         assert version in self.VERSIONS and split in self.SPLITS
+         assert source_language is not None
+         self.no_translation = target_language is None
+         if not self.no_translation:
+             assert "en" in {source_language, target_language}
+             if source_language == "en":
+                 assert target_language in self.EN_XX_LANGUAGES[version]
+             else:
+                 assert source_language in self.XX_EN_LANGUAGES[version]
+         else:
+             # Hack here so that we can get "split" column from CoVoST TSV.
+             # Note that we use CoVoST train split for ASR which is an extension
+             # to Common Voice train split.
+             target_language = "de" if source_language == "en" else "en"
+
+         self.root = os.path.join(root, "raw")
+         os.makedirs(self.root, exist_ok=True)
+
+         cv_url = self.CV_URL_TEMPLATE.format(
+             ver=self.CV_VERSION_ID[version], lang=source_language
+         )
+         cv_archive = os.path.join(self.root, os.path.basename(cv_url))
+         if download:
+             if not os.path.isfile(cv_archive):
+                 download_url(cv_url, self.root, hash_value=None)
+             extract_archive(cv_archive)
+
+         covost_url = self.COVOST_URL_TEMPLATE.format(
+             src_lang=source_language, tgt_lang=target_language
+         )
+         covost_archive = os.path.join(self.root, os.path.basename(covost_url))
+         if download:
+             if not os.path.isfile(covost_archive):
+                 download_url(covost_url, self.root, hash_value=None)
+             extract_archive(covost_archive)
+
+         cv_tsv = self.load_from_tsv(os.path.join(self.root, "validated.tsv"))
+         covost_tsv = self.load_from_tsv(
+             os.path.join(self.root, os.path.basename(covost_url).replace(".tar.gz", ""))
+         )
+         df = pd.merge(
+             left=cv_tsv[["path", "sentence", "client_id"]],
+             right=covost_tsv[["path", "translation", "split"]],
+             how="inner",
+             on="path",
+         )
+         if split == "train":
+             df = df[(df["split"] == split) | (df["split"] == f"{split}_covost")]
+         else:
+             df = df[df["split"] == split]
+         self.data = df.to_dict(orient="index").items()
+         self.data = [v for k, v in sorted(self.data, key=lambda x: x[0])]
+
+     @classmethod
+     def load_from_tsv(cls, path: str):
+         return pd.read_csv(
+             path,
+             sep="\t",
+             header=0,
+             encoding="utf-8",
+             escapechar="\\",
+             quoting=csv.QUOTE_NONE,
+             na_filter=False,
+         )
+
+     def __getitem__(
+         self, n: int
+     ) -> Tuple[Tensor, int, str, Optional[str], str, str]:
+         """Load the n-th sample from the dataset.
+
+         Args:
+             n (int): The index of the sample to be loaded
+
+         Returns:
+             tuple: ``(waveform, sample_rate, sentence, translation, speaker_id,
+             sample_id)``
+         """
+         data = self.data[n]
+         path = os.path.join(self.root, "clips", data["path"])
+         waveform, sample_rate = torchaudio.load(path)
+         sentence = data["sentence"]
+         translation = None if self.no_translation else data["translation"]
+         speaker_id = data["client_id"]
+         _id = data["path"].replace(".mp3", "")
+         return waveform, sample_rate, sentence, translation, speaker_id, _id
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+
+ def process(args):
+     root = op.join(args.data_root, args.src_lang)
+     os.makedirs(root, exist_ok=True)
+     # Extract features
+     feature_root = op.join(root, "fbank80")
+     os.makedirs(feature_root, exist_ok=True)
+     for split in CoVoST.SPLITS:
+         print(f"Fetching split {split}...")
+         dataset = CoVoST(root, split, args.src_lang, args.tgt_lang, download=True)
+         print("Extracting log mel filter bank features...")
+         for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
+             extract_fbank_features(
+                 waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
+             )
+     # Pack features into ZIP
+     zip_filename = "fbank80.zip"
+     zip_path = op.join(root, zip_filename)
+     print("ZIPing features...")
+     create_zip(feature_root, zip_path)
+     print("Fetching ZIP manifest...")
+     zip_manifest = get_zip_manifest(args.data_root, f"{args.src_lang}/{zip_filename}")
+     # Generate TSV manifest
+     print("Generating manifest...")
+     train_text = []
+     task = f"asr_{args.src_lang}"
+     if args.tgt_lang is not None:
+         task = f"st_{args.src_lang}_{args.tgt_lang}"
+     for split in CoVoST.SPLITS:
+         manifest = {c: [] for c in MANIFEST_COLUMNS}
+         dataset = CoVoST(root, split, args.src_lang, args.tgt_lang)
+         for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
+             manifest["id"].append(utt_id)
+             manifest["audio"].append(zip_manifest[utt_id])
+             duration_ms = int(wav.size(1) / sr * 1000)
+             manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+             manifest["tgt_text"].append(src_utt if args.tgt_lang is None else tgt_utt)
+             manifest["speaker"].append(speaker_id)
+         is_train_split = split.startswith("train")
+         if is_train_split:
+             train_text.extend(manifest["tgt_text"])
+         df = pd.DataFrame.from_dict(manifest)
+         df = filter_manifest_df(df, is_train_split=is_train_split)
+         save_df_to_tsv(df, op.join(root, f"{split}_{task}.tsv"))
+     # Generate vocab
+     vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
+     spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{task}"
+     with NamedTemporaryFile(mode="w") as f:
+         for t in train_text:
+             f.write(t + "\n")
+         gen_vocab(
+             f.name, op.join(root, spm_filename_prefix), args.vocab_type, args.vocab_size
+         )
+     # Generate config YAML
+     gen_config_yaml(
+         root,
+         spm_filename_prefix + ".model",
+         yaml_filename=f"config_{task}.yaml",
+         specaugment_policy="lb",
+     )
+     # Clean up
+     shutil.rmtree(feature_root)
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--data-root", "-d", required=True, type=str)
+     parser.add_argument(
+         "--vocab-type",
+         default="unigram",
+         required=True,
+         type=str,
+         choices=["bpe", "unigram", "char"],
+     )
+     parser.add_argument("--vocab-size", default=1000, type=int)
+     parser.add_argument("--src-lang", "-s", required=True, type=str)
+     parser.add_argument("--tgt-lang", "-t", type=str)
+     args = parser.parse_args()
+
+     process(args)
+
+
+ if __name__ == "__main__":
+     main()
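The manifest's `n_frames` field above is derived from the clip duration rather than from the extracted features: with a Kaldi-style 25 ms analysis window shifted by 10 ms, a clip of `duration_ms` milliseconds yields roughly `1 + (duration_ms - 25) / 10` frames. A small sketch of that arithmetic (the helper name is illustrative, not part of the script):

```python
def n_fbank_frames(num_samples: int, sample_rate: int) -> int:
    # Mirrors the manifest computation: 25 ms window, 10 ms frame shift.
    duration_ms = int(num_samples / sample_rate * 1000)
    return int(1 + (duration_ms - 25) / 10)

# One second of 16 kHz audio -> 98 frames; three seconds -> 298 frames.
assert n_fbank_frames(16000, 16000) == 98
assert n_fbank_frames(48000, 16000) == 298
```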
fairseq-0.10.2/examples/speech_to_text/prep_librispeech_data.py ADDED
@@ -0,0 +1,119 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+ import os
+ import os.path as op
+ import shutil
+ from tempfile import NamedTemporaryFile
+
+ import pandas as pd
+ from examples.speech_to_text.data_utils import (
+     create_zip,
+     extract_fbank_features,
+     gen_config_yaml,
+     gen_vocab,
+     get_zip_manifest,
+     save_df_to_tsv,
+ )
+ from torchaudio.datasets import LIBRISPEECH
+ from tqdm import tqdm
+
+
+ log = logging.getLogger(__name__)
+
+ SPLITS = [
+     "train-clean-100",
+     "train-clean-360",
+     "train-other-500",
+     "dev-clean",
+     "dev-other",
+     "test-clean",
+     "test-other",
+ ]
+
+ MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
+
+
+ def process(args):
+     os.makedirs(args.output_root, exist_ok=True)
+     # Extract features
+     feature_root = op.join(args.output_root, "fbank80")
+     os.makedirs(feature_root, exist_ok=True)
+     for split in SPLITS:
+         print(f"Fetching split {split}...")
+         dataset = LIBRISPEECH(args.output_root, url=split, download=True)
+         print("Extracting log mel filter bank features...")
+         for wav, sample_rate, _, spk_id, chapter_id, utt_id in tqdm(dataset):
+             sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
+             extract_fbank_features(
+                 wav, sample_rate, op.join(feature_root, f"{sample_id}.npy")
+             )
+     # Pack features into ZIP
+     zip_filename = "fbank80.zip"
+     zip_path = op.join(args.output_root, zip_filename)
+     print("ZIPing features...")
+     create_zip(feature_root, zip_path)
+     print("Fetching ZIP manifest...")
+     zip_manifest = get_zip_manifest(args.output_root, zip_filename)
+     # Generate TSV manifest
+     print("Generating manifest...")
+     train_text = []
+     for split in SPLITS:
+         manifest = {c: [] for c in MANIFEST_COLUMNS}
+         dataset = LIBRISPEECH(args.output_root, url=split)
+         for wav, sample_rate, utt, spk_id, chapter_id, utt_id in tqdm(dataset):
+             sample_id = f"{spk_id}-{chapter_id}-{utt_id}"
+             manifest["id"].append(sample_id)
+             manifest["audio"].append(zip_manifest[sample_id])
+             duration_ms = int(wav.size(1) / sample_rate * 1000)
+             manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+             manifest["tgt_text"].append(utt)
+             manifest["speaker"].append(spk_id)
+         save_df_to_tsv(
+             pd.DataFrame.from_dict(manifest), op.join(args.output_root, f"{split}.tsv")
+         )
+         if split.startswith("train"):
+             train_text.extend(manifest["tgt_text"])
+     # Generate vocab
+     vocab_size = "" if args.vocab_type == "char" else str(args.vocab_size)
+     spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size}"
+     with NamedTemporaryFile(mode="w") as f:
+         for t in train_text:
+             f.write(t + "\n")
+         gen_vocab(
+             f.name,
+             op.join(args.output_root, spm_filename_prefix),
+             args.vocab_type,
+             args.vocab_size,
+         )
+     # Generate config YAML
+     gen_config_yaml(
+         args.output_root, spm_filename_prefix + ".model", specaugment_policy="ld"
+     )
+     # Clean up
+     shutil.rmtree(feature_root)
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--output-root", "-o", required=True, type=str)
+     parser.add_argument(
+         "--vocab-type",
+         default="unigram",
+         required=True,
+         type=str,
+         choices=["bpe", "unigram", "char"],
+     )
+     parser.add_argument("--vocab-size", default=10000, type=int)
+     args = parser.parse_args()
+
+     process(args)
+
+
+ if __name__ == "__main__":
+     main()
fairseq-0.10.2/examples/speech_to_text/prep_mustc_data.py ADDED
@@ -0,0 +1,200 @@
+ #!/usr/bin/env python3
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import logging
+ import os
+ import os.path as op
+ import shutil
+ from itertools import groupby
+ from tempfile import NamedTemporaryFile
+ from typing import Tuple
+
+ import pandas as pd
+ import torchaudio
+ from examples.speech_to_text.data_utils import (
+     create_zip,
+     extract_fbank_features,
+     filter_manifest_df,
+     gen_config_yaml,
+     gen_vocab,
+     get_zip_manifest,
+     save_df_to_tsv,
+ )
+ from torch import Tensor
+ from torch.utils.data import Dataset
+ from tqdm import tqdm
+
+
+ log = logging.getLogger(__name__)
+
+
+ MANIFEST_COLUMNS = ["id", "audio", "n_frames", "tgt_text", "speaker"]
+ TASKS = ["asr", "st"]
+
+
+ class MUSTC(Dataset):
+     """
+     Create a Dataset for MuST-C. Each item is a tuple of the form:
+     waveform, sample_rate, source utterance, target utterance, speaker_id,
+     utterance_id
+     """
+
+     SPLITS = ["train", "dev", "tst-COMMON", "tst-HE"]
+     LANGUAGES = ["de", "es", "fr", "it", "nl", "pt", "ro", "ru"]
+
+     def __init__(self, root: str, lang: str, split: str) -> None:
+         assert split in self.SPLITS and lang in self.LANGUAGES
+         _root = op.join(root, f"en-{lang}", "data", split)
+         wav_root, txt_root = op.join(_root, "wav"), op.join(_root, "txt")
+         assert op.isdir(_root) and op.isdir(wav_root) and op.isdir(txt_root)
+         # Load audio segments
+         try:
+             import yaml
+         except ImportError:
+             print("Please install PyYAML to load YAML files for the MuST-C dataset")
+         with open(op.join(txt_root, f"{split}.yaml")) as f:
+             segments = yaml.load(f, Loader=yaml.BaseLoader)
+         # Load source and target utterances
+         for _lang in ["en", lang]:
+             with open(op.join(txt_root, f"{split}.{_lang}")) as f:
+                 utterances = [r.strip() for r in f]
+             assert len(segments) == len(utterances)
+             for i, u in enumerate(utterances):
+                 segments[i][_lang] = u
+         # Gather info
+         self.data = []
+         for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
+             wav_path = op.join(wav_root, wav_filename)
+             sample_rate = torchaudio.info(wav_path)[0].rate
+             seg_group = sorted(_seg_group, key=lambda x: x["offset"])
+             for i, segment in enumerate(seg_group):
+                 offset = int(float(segment["offset"]) * sample_rate)
+                 n_frames = int(float(segment["duration"]) * sample_rate)
+                 _id = f"{op.splitext(wav_filename)[0]}_{i}"
+                 self.data.append(
+                     (
+                         wav_path,
+                         offset,
+                         n_frames,
+                         sample_rate,
+                         segment["en"],
+                         segment[lang],
+                         segment["speaker_id"],
+                         _id,
+                     )
+                 )
+
+     def __getitem__(self, n: int) -> Tuple[Tensor, int, str, str, str, str]:
+         wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, utt_id = self.data[n]
+         waveform, _ = torchaudio.load(wav_path, offset=offset, num_frames=n_frames)
+         return waveform, sr, src_utt, tgt_utt, spk_id, utt_id
+
+     def __len__(self) -> int:
+         return len(self.data)
+
+
+ def process(args):
+     for lang in MUSTC.LANGUAGES:
+         cur_root = op.join(args.data_root, f"en-{lang}")
+         if not op.isdir(cur_root):
+             print(f"{cur_root} does not exist. Skipped.")
+             continue
+         # Extract features
+         feature_root = op.join(cur_root, "fbank80")
+         os.makedirs(feature_root, exist_ok=True)
+         for split in MUSTC.SPLITS:
+             print(f"Fetching split {split}...")
+             dataset = MUSTC(args.data_root, lang, split)
+             print("Extracting log mel filter bank features...")
+             for waveform, sample_rate, _, _, _, utt_id in tqdm(dataset):
+                 extract_fbank_features(
+                     waveform, sample_rate, op.join(feature_root, f"{utt_id}.npy")
+                 )
+         # Pack features into ZIP
+         zip_filename = "fbank80.zip"
+         zip_path = op.join(cur_root, zip_filename)
+         print("ZIPing features...")
+         create_zip(feature_root, zip_path)
+         print("Fetching ZIP manifest...")
+         zip_manifest = get_zip_manifest(args.data_root, f"en-{lang}/{zip_filename}")
+         # Generate TSV manifest
+         print("Generating manifest...")
+         train_text = {task: [] for task in TASKS}
+         for split in MUSTC.SPLITS:
+             is_train_split = split.startswith("train")
+             manifest = {c: [] for c in MANIFEST_COLUMNS}
+             text = {task: [] for task in TASKS}
+             dataset = MUSTC(args.data_root, lang, split)
+             for wav, sr, src_utt, tgt_utt, speaker_id, utt_id in tqdm(dataset):
+                 manifest["id"].append(utt_id)
+                 manifest["audio"].append(zip_manifest[utt_id])
+                 duration_ms = int(wav.size(1) / sr * 1000)
+                 manifest["n_frames"].append(int(1 + (duration_ms - 25) / 10))
+                 text["asr"].append(src_utt)
+                 text["st"].append(tgt_utt)
+                 manifest["speaker"].append(speaker_id)
+             if is_train_split:
+                 for task in TASKS:
+                     train_text[task].extend(text[task])
+             for task in TASKS:
+                 manifest["tgt_text"] = text[task]
+                 df = pd.DataFrame.from_dict(manifest)
+                 df = filter_manifest_df(df, is_train_split=is_train_split)
+                 save_df_to_tsv(df, op.join(cur_root, f"{split}_{task}.tsv"))
+         # Generate vocab
+         for task in TASKS:
+             vocab_type, vocab_size = args.asr_vocab_type, args.asr_vocab_size
+             if task == "st":
+                 vocab_type, vocab_size = args.st_vocab_type, args.st_vocab_size
+             vocab_size_str = "" if vocab_type == "char" else str(vocab_size)
+             spm_filename_prefix = f"spm_{vocab_type}{vocab_size_str}_{task}"
+             with NamedTemporaryFile(mode="w") as f:
+                 for t in train_text[task]:
+                     f.write(t + "\n")
+                 gen_vocab(
+                     f.name,
+                     op.join(cur_root, spm_filename_prefix),
+                     vocab_type,
+                     vocab_size,
+                 )
+             # Generate config YAML
+             gen_config_yaml(
+                 cur_root,
+                 spm_filename_prefix + ".model",
+                 yaml_filename=f"config_{task}.yaml",
+                 specaugment_policy="lb",
+             )
+         # Clean up
+         shutil.rmtree(feature_root)
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--data-root", "-d", required=True, type=str)
+     parser.add_argument(
+         "--asr-vocab-type",
+         default="unigram",
+         required=True,
+         type=str,
+         choices=["bpe", "unigram", "char"],
+     )
+     parser.add_argument(
+         "--st-vocab-type",
+         default="unigram",
+         required=True,
+         type=str,
+         choices=["bpe", "unigram", "char"],
+     )
+     parser.add_argument("--asr-vocab-size", default=5000, type=int)
+     parser.add_argument("--st-vocab-size", default=8000, type=int)
+     args = parser.parse_args()
+
+     process(args)
+
+
+ if __name__ == "__main__":
+     main()
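Note that `MUSTC.__init__` above groups segments with `itertools.groupby`, which only merges *adjacent* items sharing a key; this is safe here because the MuST-C YAML lists segments file by file. A minimal illustration of that behavior:

```python
from itertools import groupby

# groupby merges runs of adjacent equal keys, not all equal keys globally,
# so grouping segments by "wav" relies on the input being ordered by file.
segments = [
    {"wav": "a.wav", "offset": 0.0},
    {"wav": "a.wav", "offset": 1.5},
    {"wav": "b.wav", "offset": 0.0},
]
groups = {k: len(list(g)) for k, g in groupby(segments, lambda x: x["wav"])}
assert groups == {"a.wav": 2, "b.wav": 1}
```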
fairseq-0.10.2/examples/unsupervised_quality_estimation/README.md ADDED
@@ -0,0 +1,126 @@
+ # Unsupervised Quality Estimation for Neural Machine Translation (Fomicheva et al., 2020)
2
+
3
+ This page includes instructions for reproducing results from the paper [Unsupervised Quality Estimation for Neural
4
+ Machine Translation (Fomicheva et al., 2020)](https://arxiv.org/abs/2005.10608)
5
+
6
+ ## Requirements:
7
+
8
+ * mosesdecoder: https://github.com/moses-smt/mosesdecoder
9
+ * subword-nmt: https://github.com/rsennrich/subword-nmt
10
+ * flores: https://github.com/facebookresearch/flores
11
+
12
+ ## Download Models and Test Data
13
+
14
+ Download translation models and test data from [MLQE dataset repository](https://github.com/facebookresearch/mlqe).
15
+
16
+ ## Set up:
17
+
18
+ Given a testset consisting of source sentences and reference translations:
19
+
20
+ * `SRC_LANG`: source language
21
+ * `TGT_LANG`: target language
22
+ * `INPUT`: input prefix, such that the file `$INPUT.$SRC_LANG` contains source sentences and `$INPUT.$TGT_LANG`
23
+ contains the reference sentences
24
+ * `OUTPUT_DIR`: output path to store results
25
+ * `MOSES_DECODER`: path to mosesdecoder installation
26
+ * `BPE_ROOT`: path to subword-nmt installation
27
+ * `BPE`: path to BPE model
28
+ * `MODEL_DIR`: directory containing the NMT model `.pt` file as well as the source and target vocabularies.
29
+ * `TMP`: directory for intermediate temporary files
30
+ * `GPU`: if translating with GPU, id of the GPU to use for inference
31
+ * `DROPOUT_N`: number of stochastic forward passes
32
+
33
+ `$DROPOUT_N` is set to 30 in the experiments reported in the paper. However, we observed that increasing it beyond 10
34
+ does not bring substantial improvements.
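As a concrete starting point, the variables above might be set as follows. Every value here is a placeholder, not part of the original instructions; substitute your own languages and installation paths:

```shell
# Hypothetical example values -- substitute your own paths and languages.
export SRC_LANG=en
export TGT_LANG=de
export INPUT=data/test                    # expects data/test.en and data/test.de
export OUTPUT_DIR=output
export MOSES_DECODER=$HOME/mosesdecoder
export BPE_ROOT=$HOME/subword-nmt/subword_nmt
export MODEL_DIR=models/en-de
export BPE=$MODEL_DIR/bpecodes            # assumed BPE model filename
export TMP=$(mktemp -d)
export GPU=0
export DROPOUT_N=10
mkdir -p $OUTPUT_DIR
```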
+ 
+ ## Translate the data using standard decoding
+ 
+ Preprocess the input data:
+ 
+ ```
+ for LANG in $SRC_LANG $TGT_LANG; do
+   perl $MOSES_DECODER/scripts/tokenizer/tokenizer.perl -threads 80 -a -l $LANG < $INPUT.$LANG > $TMP/preprocessed.tok.$LANG
+   python $BPE_ROOT/apply_bpe.py -c ${BPE} < $TMP/preprocessed.tok.$LANG > $TMP/preprocessed.tok.bpe.$LANG
+ done
+ ```
+ 
+ Binarize the data for faster translation:
+ 
+ ```
+ fairseq-preprocess --srcdict $MODEL_DIR/dict.$SRC_LANG.txt --tgtdict $MODEL_DIR/dict.$TGT_LANG.txt \
+   --source-lang ${SRC_LANG} --target-lang ${TGT_LANG} --testpref $TMP/preprocessed.tok.bpe --destdir $TMP/bin --workers 4
+ ```
+ 
+ Translate:
+ 
+ ```
+ CUDA_VISIBLE_DEVICES=$GPU fairseq-generate $TMP/bin --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 \
+   --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 > $TMP/fairseq.out
+ grep ^H $TMP/fairseq.out | cut -f3- > $TMP/mt.out
+ ```
+ 
+ Post-process:
+ 
+ ```
+ sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/mt.out | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl \
+   -l $TGT_LANG > $OUTPUT_DIR/mt.out
+ ```
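The `sed` expression in the post-processing step undoes subword-nmt's `@@ ` BPE joiners. The same logic, written in Python as a sketch (this helper is not part of the original scripts), is:

```python
import re


def remove_bpe(line: str) -> str:
    """Undo subword-nmt BPE segmentation; equivalent to sed -r 's/(@@ )| (@@ ?$)//g'."""
    return re.sub(r"(@@ )| (@@ ?$)", "", line)


print(remove_bpe("the qu@@ ick bro@@ wn fox"))  # -> "the quick brown fox"
```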
+ 
+ ## Produce uncertainty estimates
+ 
+ ### Scoring
+ 
+ Make temporary files to store the translations repeated N times:
+ 
+ ```
+ python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/preprocessed.tok.bpe.$SRC_LANG -n $DROPOUT_N \
+   -o $TMP/repeated.$SRC_LANG
+ python ${SCRIPTS}/scripts/uncertainty/repeat_lines.py -i $TMP/mt.out -n $DROPOUT_N -o $TMP/repeated.$TGT_LANG
+ 
+ fairseq-preprocess --srcdict ${MODEL_DIR}/dict.${SRC_LANG}.txt --tgtdict ${MODEL_DIR}/dict.${TGT_LANG}.txt \
+   --source-lang ${SRC_LANG} --target-lang ${TGT_LANG} --testpref ${TMP}/repeated --destdir ${TMP}/bin-repeated
+ ```
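A minimal sketch of what the repetition step does, assuming `repeat_lines.py` writes each input line `$DROPOUT_N` times consecutively so that the i-th group of consecutive lines corresponds to the i-th original sentence (an assumed layout, inferred from how the scores are later aggregated per sentence):

```python
def repeat_lines(lines, n):
    """Repeat each line n times consecutively (assumed repeat_lines.py behaviour):
    input line i becomes output lines i*n .. i*n + n-1."""
    return [line for line in lines for _ in range(n)]


print(repeat_lines(["a", "b"], 3))  # -> ['a', 'a', 'a', 'b', 'b', 'b']
```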
+ 
+ Produce model scores for the generated translations, using the `--retain-dropout` option to apply dropout at inference time:
+ 
+ ```
+ CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt --beam 5 \
+   --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --unkpen 5 --score-reference --retain-dropout \
+   --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder TransformerEncoderLayer \
+   TransformerDecoderLayer --seed 46 > $TMP/dropout.scoring.out
+ 
+ grep ^H $TMP/dropout.scoring.out | cut -f2- > $TMP/dropout.scores
+ ```
+ 
+ Use `--retain-dropout-modules` to specify the modules in which dropout should be retained. By default, dropout is applied in the same places
+ as during training.
+ 
+ Compute the mean of the resulting output distribution:
+ 
+ ```
+ python $SCRIPTS/scripts/uncertainty/aggregate_scores.py -i $TMP/dropout.scores -o $OUTPUT_DIR/dropout.scores.mean \
+   -n $DROPOUT_N
+ ```
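The aggregation reduces the `$DROPOUT_N` stochastic scores per sentence to a single value. Assuming the consecutive-grouping layout described above (each sentence's scores are adjacent in the file), the per-sentence mean can be sketched as follows; this is an illustrative reimplementation, not the original `aggregate_scores.py`:

```python
def aggregate_means(scores, n):
    """Average each consecutive group of n scores (one group per source sentence)."""
    assert len(scores) % n == 0, "score count must be a multiple of n"
    return [sum(scores[i:i + n]) / n for i in range(0, len(scores), n)]


print(aggregate_means([-1.0, -2.0, -3.0, -4.0, -5.0, -6.0], 3))  # -> [-2.0, -5.0]
```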
+ 
+ ### Generation
+ 
+ Produce multiple translation hypotheses for the same source using the `--retain-dropout` option:
+ 
+ ```
+ CUDA_VISIBLE_DEVICES=${GPU} fairseq-generate ${TMP}/bin-repeated --path ${MODEL_DIR}/${SRC_LANG}-${TGT_LANG}.pt \
+   --beam 5 --source-lang $SRC_LANG --target-lang $TGT_LANG --no-progress-bar --retain-dropout \
+   --unkpen 5 --retain-dropout-modules TransformerModel TransformerEncoder TransformerDecoder \
+   TransformerEncoderLayer TransformerDecoderLayer --seed 46 > $TMP/dropout.generation.out
+ 
+ grep ^H $TMP/dropout.generation.out | cut -f3- > $TMP/dropout.hypotheses_
+ 
+ sed -r 's/(@@ )| (@@ ?$)//g' < $TMP/dropout.hypotheses_ | perl $MOSES_DECODER/scripts/tokenizer/detokenizer.perl \
+   -l $TGT_LANG > $TMP/dropout.hypotheses
+ ```
120
+
121
+ Compute similarity between multiple hypotheses corresponding to the same source sentence using Meteor
122
+ evaluation metric:
123
+ ```
124
+ python meteor.py -i $TMP/dropout.hypotheses -m <path_to_meteor_installation> -n $DROPOUT_N -o
125
+ $OUTPUT_DIR/dropout.gen.sim.meteor
126
+ ```
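Conceptually, this step scores every pair of hypotheses within each group of `$DROPOUT_N` and averages the pairwise similarities. The sketch below illustrates that structure using a toy token-overlap similarity as a stand-in for Meteor (the actual script calls out to a Meteor installation; `token_overlap` is a hypothetical placeholder, not the real metric):

```python
from itertools import combinations


def token_overlap(a, b):
    """Toy stand-in for Meteor: Jaccard overlap of token sets (NOT the real metric)."""
    ta, tb = set(a.split()), set(b.split())
    return len(ta & tb) / max(len(ta | tb), 1)


def avg_pairwise_similarity(hypotheses, n, sim=token_overlap):
    """Mean pairwise similarity within each consecutive group of n hypotheses."""
    out = []
    for i in range(0, len(hypotheses), n):
        group = hypotheses[i:i + n]
        pairs = list(combinations(group, 2))
        out.append(sum(sim(a, b) for a, b in pairs) / len(pairs))
    return out


print(avg_pairwise_similarity(["a b", "a b", "a c"], 3))
```

Higher average similarity across stochastic hypotheses indicates lower model uncertainty for that source sentence.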