Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- fairseq-0.10.2/examples/camembert/README.md +75 -0
- fairseq-0.10.2/examples/joint_alignment_translation/README.md +89 -0
- fairseq-0.10.2/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh +118 -0
- fairseq-0.10.2/examples/language_model/README.adaptive_inputs.md +39 -0
- fairseq-0.10.2/examples/language_model/README.conv.md +40 -0
- fairseq-0.10.2/examples/language_model/README.md +123 -0
- fairseq-0.10.2/examples/language_model/prepare-wikitext-103.sh +33 -0
- fairseq-0.10.2/examples/linformer/README.md +22 -0
- fairseq-0.10.2/examples/linformer/linformer_src/__init__.py +6 -0
- fairseq-0.10.2/examples/linformer/linformer_src/models/__init__.py +0 -0
- fairseq-0.10.2/examples/linformer/linformer_src/models/linformer_roberta.py +134 -0
- fairseq-0.10.2/examples/linformer/linformer_src/modules/__init__.py +0 -0
- fairseq-0.10.2/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py +169 -0
- fairseq-0.10.2/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py +84 -0
- fairseq-0.10.2/examples/linformer/linformer_src/modules/multihead_linear_attention.py +485 -0
- fairseq-0.10.2/examples/multilingual/README.md +124 -0
- fairseq-0.10.2/examples/multilingual/finetune_multilingual_model.sh +27 -0
- fairseq-0.10.2/examples/multilingual/multilingual_fairseq_gen.sh +21 -0
- fairseq-0.10.2/examples/multilingual/train_multilingual_model.sh +23 -0
- fairseq-0.10.2/examples/nonautoregressive_translation/README.md +146 -0
- fairseq-0.10.2/examples/nonautoregressive_translation/scripts.md +179 -0
- fairseq-0.10.2/examples/pay_less_attention_paper/README.md +176 -0
- fairseq-0.10.2/examples/pointer_generator/README.md +82 -0
- fairseq-0.10.2/examples/pointer_generator/README.xsum.md +180 -0
- fairseq-0.10.2/examples/pointer_generator/pointer_generator_src/__init__.py +6 -0
- fairseq-0.10.2/examples/pointer_generator/pointer_generator_src/transformer_pg.py +468 -0
- fairseq-0.10.2/examples/pointer_generator/postprocess.py +96 -0
- fairseq-0.10.2/examples/pointer_generator/preprocess.py +102 -0
- fairseq-0.10.2/examples/roberta/README.custom_classification.md +168 -0
- fairseq-0.10.2/examples/roberta/commonsense_qa/README.md +99 -0
- fairseq-0.10.2/examples/roberta/commonsense_qa/__init__.py +6 -0
- fairseq-0.10.2/examples/roberta/commonsense_qa/commonsense_qa_task.py +190 -0
- fairseq-0.10.2/examples/roberta/commonsense_qa/download_cqa_data.sh +14 -0
- fairseq-0.10.2/examples/roberta/preprocess_RACE.py +102 -0
- fairseq-0.10.2/examples/roberta/wsc/README.md +125 -0
- fairseq-0.10.2/examples/roberta/wsc/__init__.py +7 -0
- fairseq-0.10.2/examples/roberta/wsc/wsc_criterion.py +167 -0
- fairseq-0.10.2/examples/roberta/wsc/wsc_task.py +401 -0
- fairseq-0.10.2/examples/roberta/wsc/wsc_utils.py +241 -0
- fairseq-0.10.2/examples/scaling_nmt/README.md +114 -0
- fairseq-0.10.2/examples/simultaneous_translation/criterions/__init__.py +15 -0
- fairseq-0.10.2/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py +73 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/agents/__init__.py +24 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/agents/agent.py +67 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/agents/simul_trans_agent.py +167 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py +81 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/scorers/__init__.py +19 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/scorers/scorer.py +175 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/scorers/text_scorer.py +41 -0
- fairseq-0.10.2/examples/speech_recognition/criterions/ASG_loss.py +170 -0
fairseq-0.10.2/examples/camembert/README.md
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CamemBERT: a Tasty French Language Model
|
| 2 |
+
|
| 3 |
+
## Introduction
|
| 4 |
+
|
| 5 |
+
[CamemBERT](https://arxiv.org/abs/1911.03894) is a pretrained language model trained on 138GB of French text based on RoBERTa.
|
| 6 |
+
|
| 7 |
+
Also available in [github.com/huggingface/transformers](https://github.com/huggingface/transformers/).
|
| 8 |
+
|
| 9 |
+
## Pre-trained models
|
| 10 |
+
|
| 11 |
+
| Model | #params | Download | Arch. | Training data |
|
| 12 |
+
|--------------------------------|---------|--------------------------------------------------------------------------------------------------------------------------|-------|-----------------------------------|
|
| 13 |
+
| `camembert` / `camembert-base` | 110M | [camembert-base.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz) | Base | OSCAR (138 GB of text) |
|
| 14 |
+
| `camembert-large` | 335M | [camembert-large.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-large.tar.gz) | Large | CCNet (135 GB of text) |
|
| 15 |
+
| `camembert-base-ccnet` | 110M | [camembert-base-ccnet.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet.tar.gz) | Base | CCNet (135 GB of text) |
|
| 16 |
+
| `camembert-base-wikipedia-4gb` | 110M | [camembert-base-wikipedia-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-wikipedia-4gb.tar.gz) | Base | Wikipedia (4 GB of text) |
|
| 17 |
+
| `camembert-base-oscar-4gb` | 110M | [camembert-base-oscar-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-oscar-4gb.tar.gz) | Base | Subsample of OSCAR (4 GB of text) |
|
| 18 |
+
| `camembert-base-ccnet-4gb` | 110M | [camembert-base-ccnet-4gb.tar.gz](https://dl.fbaipublicfiles.com/fairseq/models/camembert-base-ccnet-4gb.tar.gz) | Base | Subsample of CCNet (4 GB of text) |
|
| 19 |
+
|
| 20 |
+
## Example usage
|
| 21 |
+
|
| 22 |
+
### fairseq
|
| 23 |
+
##### Load CamemBERT from torch.hub (PyTorch >= 1.1):
|
| 24 |
+
```python
|
| 25 |
+
import torch
|
| 26 |
+
camembert = torch.hub.load('pytorch/fairseq', 'camembert')
|
| 27 |
+
camembert.eval() # disable dropout (or leave in train mode to finetune)
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
##### Load CamemBERT (for PyTorch 1.0 or custom models):
|
| 31 |
+
```python
|
| 32 |
+
# Download camembert model
|
| 33 |
+
wget https://dl.fbaipublicfiles.com/fairseq/models/camembert-base.tar.gz
|
| 34 |
+
tar -xzvf camembert.tar.gz
|
| 35 |
+
|
| 36 |
+
# Load the model in fairseq
|
| 37 |
+
from fairseq.models.roberta import CamembertModel
|
| 38 |
+
camembert = CamembertModel.from_pretrained('/path/to/camembert')
|
| 39 |
+
camembert.eval() # disable dropout (or leave in train mode to finetune)
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
##### Filling masks:
|
| 43 |
+
```python
|
| 44 |
+
masked_line = 'Le camembert est <mask> :)'
|
| 45 |
+
camembert.fill_mask(masked_line, topk=3)
|
| 46 |
+
# [('Le camembert est délicieux :)', 0.4909118115901947, ' délicieux'),
|
| 47 |
+
# ('Le camembert est excellent :)', 0.10556942224502563, ' excellent'),
|
| 48 |
+
# ('Le camembert est succulent :)', 0.03453322499990463, ' succulent')]
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
##### Extract features from Camembert:
|
| 52 |
+
```python
|
| 53 |
+
# Extract the last layer's features
|
| 54 |
+
line = "J'aime le camembert !"
|
| 55 |
+
tokens = camembert.encode(line)
|
| 56 |
+
last_layer_features = camembert.extract_features(tokens)
|
| 57 |
+
assert last_layer_features.size() == torch.Size([1, 10, 768])
|
| 58 |
+
|
| 59 |
+
# Extract all layer's features (layer 0 is the embedding layer)
|
| 60 |
+
all_layers = camembert.extract_features(tokens, return_all_hiddens=True)
|
| 61 |
+
assert len(all_layers) == 13
|
| 62 |
+
assert torch.all(all_layers[-1] == last_layer_features)
|
| 63 |
+
```
|
| 64 |
+
|
| 65 |
+
## Citation
|
| 66 |
+
If you use our work, please cite:
|
| 67 |
+
|
| 68 |
+
```bibtex
|
| 69 |
+
@inproceedings{martin2020camembert,
|
| 70 |
+
title={CamemBERT: a Tasty French Language Model},
|
| 71 |
+
author={Martin, Louis and Muller, Benjamin and Su{\'a}rez, Pedro Javier Ortiz and Dupont, Yoann and Romary, Laurent and de la Clergerie, {\'E}ric Villemonte and Seddah, Djam{\'e} and Sagot, Beno{\^\i}t},
|
| 72 |
+
booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
|
| 73 |
+
year={2020}
|
| 74 |
+
}
|
| 75 |
+
```
|
fairseq-0.10.2/examples/joint_alignment_translation/README.md
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)
|
| 2 |
+
|
| 3 |
+
This page includes instructions for training models described in [Jointly Learning to Align and Translate with Transformer Models (Garg et al., 2019)](https://arxiv.org/abs/1909.02074).
|
| 4 |
+
|
| 5 |
+
## Training a joint alignment-translation model on WMT'18 En-De
|
| 6 |
+
|
| 7 |
+
##### 1. Extract and preprocess the WMT'18 En-De data
|
| 8 |
+
```bash
|
| 9 |
+
./prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
##### 2. Generate alignments from statistical alignment toolkits e.g. Giza++/FastAlign.
|
| 13 |
+
In this example, we use FastAlign.
|
| 14 |
+
```bash
|
| 15 |
+
git clone git@github.com:clab/fast_align.git
|
| 16 |
+
pushd fast_align
|
| 17 |
+
mkdir build
|
| 18 |
+
cd build
|
| 19 |
+
cmake ..
|
| 20 |
+
make
|
| 21 |
+
popd
|
| 22 |
+
ALIGN=fast_align/build/fast_align
|
| 23 |
+
paste bpe.32k/train.en bpe.32k/train.de | awk -F '\t' '{print $1 " ||| " $2}' > bpe.32k/train.en-de
|
| 24 |
+
$ALIGN -i bpe.32k/train.en-de -d -o -v > bpe.32k/train.align
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
##### 3. Preprocess the dataset with the above generated alignments.
|
| 28 |
+
```bash
|
| 29 |
+
fairseq-preprocess \
|
| 30 |
+
--source-lang en --target-lang de \
|
| 31 |
+
--trainpref bpe.32k/train \
|
| 32 |
+
--validpref bpe.32k/valid \
|
| 33 |
+
--testpref bpe.32k/test \
|
| 34 |
+
--align-suffix align \
|
| 35 |
+
--destdir binarized/ \
|
| 36 |
+
--joined-dictionary \
|
| 37 |
+
--workers 32
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
##### 4. Train a model
|
| 41 |
+
```bash
|
| 42 |
+
fairseq-train \
|
| 43 |
+
binarized \
|
| 44 |
+
--arch transformer_wmt_en_de_big_align --share-all-embeddings \
|
| 45 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --activation-fn relu\
|
| 46 |
+
--lr 0.0002 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
|
| 47 |
+
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
|
| 48 |
+
--max-tokens 3500 --label-smoothing 0.1 \
|
| 49 |
+
--save-dir ./checkpoints --log-interval 1000 --max-update 60000 \
|
| 50 |
+
--keep-interval-updates -1 --save-interval-updates 0 \
|
| 51 |
+
--load-alignments --criterion label_smoothed_cross_entropy_with_alignment \
|
| 52 |
+
--fp16
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU or newer.
|
| 56 |
+
|
| 57 |
+
If you want to train the above model with big batches (assuming your machine has 8 GPUs):
|
| 58 |
+
- add `--update-freq 8` to simulate training on 8x8=64 GPUs
|
| 59 |
+
- increase the learning rate; 0.0007 works well for big batches
|
| 60 |
+
|
| 61 |
+
##### 5. Evaluate and generate the alignments (BPE level)
|
| 62 |
+
```bash
|
| 63 |
+
fairseq-generate \
|
| 64 |
+
binarized --gen-subset test --print-alignment \
|
| 65 |
+
--source-lang en --target-lang de \
|
| 66 |
+
--path checkpoints/checkpoint_best.pt --beam 5 --nbest 1
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
##### 6. Other resources.
|
| 70 |
+
The code for:
|
| 71 |
+
1. preparing alignment test sets
|
| 72 |
+
2. converting BPE level alignments to token level alignments
|
| 73 |
+
3. symmetrizing bidirectional alignments
|
| 74 |
+
4. evaluating alignments using AER metric
|
| 75 |
+
can be found [here](https://github.com/lilt/alignment-scripts)
|
| 76 |
+
|
| 77 |
+
## Citation
|
| 78 |
+
|
| 79 |
+
```bibtex
|
| 80 |
+
@inproceedings{garg2019jointly,
|
| 81 |
+
title = {Jointly Learning to Align and Translate with Transformer Models},
|
| 82 |
+
author = {Garg, Sarthak and Peitz, Stephan and Nallasamy, Udhyakumar and Paulik, Matthias},
|
| 83 |
+
booktitle = {Conference on Empirical Methods in Natural Language Processing (EMNLP)},
|
| 84 |
+
address = {Hong Kong},
|
| 85 |
+
month = {November},
|
| 86 |
+
url = {https://arxiv.org/abs/1909.02074},
|
| 87 |
+
year = {2019},
|
| 88 |
+
}
|
| 89 |
+
```
|
fairseq-0.10.2/examples/joint_alignment_translation/prepare-wmt18en2de_no_norm_no_escape_no_agressive.sh
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the MIT license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
echo 'Cloning Moses github repository (for tokenization scripts)...'
|
| 9 |
+
git clone https://github.com/moses-smt/mosesdecoder.git
|
| 10 |
+
|
| 11 |
+
SCRIPTS=mosesdecoder/scripts
|
| 12 |
+
TOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
|
| 13 |
+
CLEAN=$SCRIPTS/training/clean-corpus-n.perl
|
| 14 |
+
REM_NON_PRINT_CHAR=$SCRIPTS/tokenizer/remove-non-printing-char.perl
|
| 15 |
+
|
| 16 |
+
URLS=(
|
| 17 |
+
"http://statmt.org/wmt13/training-parallel-europarl-v7.tgz"
|
| 18 |
+
"http://statmt.org/wmt13/training-parallel-commoncrawl.tgz"
|
| 19 |
+
"http://data.statmt.org/wmt18/translation-task/training-parallel-nc-v13.tgz"
|
| 20 |
+
"http://data.statmt.org/wmt18/translation-task/rapid2016.tgz"
|
| 21 |
+
"http://data.statmt.org/wmt17/translation-task/dev.tgz"
|
| 22 |
+
"http://statmt.org/wmt14/test-full.tgz"
|
| 23 |
+
)
|
| 24 |
+
CORPORA=(
|
| 25 |
+
"training/europarl-v7.de-en"
|
| 26 |
+
"commoncrawl.de-en"
|
| 27 |
+
"training-parallel-nc-v13/news-commentary-v13.de-en"
|
| 28 |
+
"rapid2016.de-en"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
if [ ! -d "$SCRIPTS" ]; then
|
| 32 |
+
echo "Please set SCRIPTS variable correctly to point to Moses scripts."
|
| 33 |
+
exit
|
| 34 |
+
fi
|
| 35 |
+
|
| 36 |
+
src=en
|
| 37 |
+
tgt=de
|
| 38 |
+
lang=en-de
|
| 39 |
+
prep=wmt18_en_de
|
| 40 |
+
tmp=$prep/tmp
|
| 41 |
+
orig=orig
|
| 42 |
+
dev=dev/newstest2012
|
| 43 |
+
codes=32000
|
| 44 |
+
bpe=bpe.32k
|
| 45 |
+
|
| 46 |
+
mkdir -p $orig $tmp $prep $bpe
|
| 47 |
+
|
| 48 |
+
cd $orig
|
| 49 |
+
|
| 50 |
+
for ((i=0;i<${#URLS[@]};++i)); do
|
| 51 |
+
url=${URLS[i]}
|
| 52 |
+
file=$(basename $url)
|
| 53 |
+
if [ -f $file ]; then
|
| 54 |
+
echo "$file already exists, skipping download"
|
| 55 |
+
else
|
| 56 |
+
wget "$url"
|
| 57 |
+
if [ -f $file ]; then
|
| 58 |
+
echo "$url successfully downloaded."
|
| 59 |
+
else
|
| 60 |
+
echo "$url not successfully downloaded."
|
| 61 |
+
exit 1
|
| 62 |
+
fi
|
| 63 |
+
if [ ${file: -4} == ".tgz" ]; then
|
| 64 |
+
tar zxvf $file
|
| 65 |
+
elif [ ${file: -4} == ".tar" ]; then
|
| 66 |
+
tar xvf $file
|
| 67 |
+
fi
|
| 68 |
+
fi
|
| 69 |
+
done
|
| 70 |
+
cd ..
|
| 71 |
+
|
| 72 |
+
echo "pre-processing train data..."
|
| 73 |
+
for l in $src $tgt; do
|
| 74 |
+
rm -rf $tmp/train.tags.$lang.tok.$l
|
| 75 |
+
for f in "${CORPORA[@]}"; do
|
| 76 |
+
cat $orig/$f.$l | \
|
| 77 |
+
perl $REM_NON_PRINT_CHAR | \
|
| 78 |
+
perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/train.tags.$lang.tok.$l
|
| 79 |
+
done
|
| 80 |
+
done
|
| 81 |
+
|
| 82 |
+
echo "pre-processing test data..."
|
| 83 |
+
for l in $src $tgt; do
|
| 84 |
+
if [ "$l" == "$src" ]; then
|
| 85 |
+
t="src"
|
| 86 |
+
else
|
| 87 |
+
t="ref"
|
| 88 |
+
fi
|
| 89 |
+
grep '<seg id' $orig/test-full/newstest2014-deen-$t.$l.sgm | \
|
| 90 |
+
sed -e 's/<seg id="[0-9]*">\s*//g' | \
|
| 91 |
+
sed -e 's/\s*<\/seg>\s*//g' | \
|
| 92 |
+
sed -e "s/\’/\'/g" | \
|
| 93 |
+
perl $TOKENIZER -threads 8 -l $l -no-escape > $tmp/test.$l
|
| 94 |
+
echo ""
|
| 95 |
+
done
|
| 96 |
+
|
| 97 |
+
# apply length filtering before BPE
|
| 98 |
+
perl $CLEAN -ratio 1.5 $tmp/train.tags.$lang.tok $src $tgt $tmp/train 1 100
|
| 99 |
+
|
| 100 |
+
# use newstest2012 for valid
|
| 101 |
+
echo "pre-processing valid data..."
|
| 102 |
+
for l in $src $tgt; do
|
| 103 |
+
rm -rf $tmp/valid.$l
|
| 104 |
+
cat $orig/$dev.$l | \
|
| 105 |
+
perl $REM_NON_PRINT_CHAR | \
|
| 106 |
+
perl $TOKENIZER -threads 8 -l $l -no-escape >> $tmp/valid.$l
|
| 107 |
+
done
|
| 108 |
+
|
| 109 |
+
mkdir output
|
| 110 |
+
mv $tmp/{train,valid,test}.{$src,$tgt} output
|
| 111 |
+
|
| 112 |
+
#BPE
|
| 113 |
+
git clone https://github.com/glample/fastBPE.git
|
| 114 |
+
pushd fastBPE
|
| 115 |
+
g++ -std=c++11 -pthread -O3 fastBPE/main.cc -IfastBPE -o fast
|
| 116 |
+
popd
|
| 117 |
+
fastBPE/fast learnbpe $codes output/train.$src output/train.$tgt > $bpe/codes
|
| 118 |
+
for split in {train,valid,test}; do for lang in {en,de}; do fastBPE/fast applybpe $bpe/$split.$lang output/$split.$lang $bpe/codes; done; done
|
fairseq-0.10.2/examples/language_model/README.adaptive_inputs.md
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adaptive Input Representations for Neural Language Modeling (Baevski and Auli, 2018)
|
| 2 |
+
|
| 3 |
+
## Pre-trained models
|
| 4 |
+
|
| 5 |
+
Description | Parameters | Dataset | Model and Test set(s)
|
| 6 |
+
---|---:|---|---
|
| 7 |
+
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 1026M | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
|
| 8 |
+
Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) | 247M | [WikiText-103](https://blog.einstein.ai/the-wikitext-long-term-dependency-language-modeling-dataset/) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
|
| 9 |
+
|
| 10 |
+
## Training an LM with adaptive inputs
|
| 11 |
+
|
| 12 |
+
First, see the general [language modeling README](README.md) for instructions on
|
| 13 |
+
preprocessing the WikiText-103 data.
|
| 14 |
+
|
| 15 |
+
Then use the following training command to train a model with adaptive inputs
|
| 16 |
+
using the `transformer_lm_wiki103` model architecture:
|
| 17 |
+
```bash
|
| 18 |
+
fairseq-train --task language_modeling \
|
| 19 |
+
data-bin/wikitext-103 \
|
| 20 |
+
--save-dir checkpoints/transformer_wikitext-103 \
|
| 21 |
+
--arch transformer_lm_wiki103 \
|
| 22 |
+
--max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
|
| 23 |
+
--warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
|
| 24 |
+
--criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
|
| 25 |
+
--sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## Citation
|
| 29 |
+
|
| 30 |
+
```bibtex
|
| 31 |
+
@inproceedings{
|
| 32 |
+
baevski2018adaptive,
|
| 33 |
+
title={Adaptive Input Representations for Neural Language Modeling},
|
| 34 |
+
author={Alexei Baevski and Michael Auli},
|
| 35 |
+
booktitle={International Conference on Learning Representations},
|
| 36 |
+
year={2019},
|
| 37 |
+
url={https://openreview.net/forum?id=ByxZX20qFQ},
|
| 38 |
+
}
|
| 39 |
+
```
|
fairseq-0.10.2/examples/language_model/README.conv.md
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Language Modeling with Gated Convolutional Networks (Dauphin et al., 2017)
|
| 2 |
+
|
| 3 |
+
## Example usage
|
| 4 |
+
|
| 5 |
+
First download and preprocess the data following the main [language modeling README](README.md).
|
| 6 |
+
|
| 7 |
+
Then to train a convolutional LM using the `fconv_lm_dauphin_wikitext103`
|
| 8 |
+
architecture:
|
| 9 |
+
```bash
|
| 10 |
+
fairseq-train --task language_modeling \
|
| 11 |
+
data-bin/wikitext-103 \
|
| 12 |
+
--save-dir checkpoints/fconv_wikitext-103 \
|
| 13 |
+
--arch fconv_lm_dauphin_wikitext103 \
|
| 14 |
+
--adaptive-softmax-cutoff 10000,20000,200000 \
|
| 15 |
+
--dropout 0.2 \
|
| 16 |
+
--criterion adaptive_loss \
|
| 17 |
+
--optimizer nag --clip-norm 0.1 --weight-decay 5e-06 \
|
| 18 |
+
--lr 1.0 --lr-scheduler reduce_lr_on_plateau --lr-shrink 0.5 \
|
| 19 |
+
--max-tokens 1024 --tokens-per-sample 1024 \
|
| 20 |
+
--ddp-backend no_c10d \
|
| 21 |
+
--max-epoch 35
|
| 22 |
+
```
|
| 23 |
+
|
| 24 |
+
And evaluate with:
|
| 25 |
+
```bash
|
| 26 |
+
fairseq-eval-lm data-bin/wikitext-103 --path checkpoints/fconv_wiki103/checkpoint_best.pt
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## Citation
|
| 30 |
+
|
| 31 |
+
```bibtex
|
| 32 |
+
@inproceedings{dauphin2017language,
|
| 33 |
+
title={Language Modeling with Gated Convolutional Networks},
|
| 34 |
+
author={Dauphin, Yann N and Fan, Angela and Auli, Michael and Grangier, David},
|
| 35 |
+
booktitle={Proceedings of the 34th International Conference on Machine Learning-Volume 70},
|
| 36 |
+
pages={933--941},
|
| 37 |
+
year={2017},
|
| 38 |
+
organization={JMLR}
|
| 39 |
+
}
|
| 40 |
+
```
|
fairseq-0.10.2/examples/language_model/README.md
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Neural Language Modeling
|
| 2 |
+
|
| 3 |
+
## Pre-trained models
|
| 4 |
+
|
| 5 |
+
Model | Description | Dataset | Download
|
| 6 |
+
---|---|---|---
|
| 7 |
+
`transformer_lm.gbw.adaptive_huge` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 1026M params | [Google Billion Words](https://github.com/ciprian-chelba/1-billion-word-language-modeling-benchmark) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_gbw_huge.tar.bz2)
|
| 8 |
+
`transformer_lm.wiki103.adaptive` | Adaptive Inputs <br> ([Baevski and Auli, 2018](https://arxiv.org/abs/1809.10853)) <br> 247M params | [WikiText-103](https://einstein.ai/research/the-wikitext-long-term-dependency-language-modeling-dataset) | [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/lm/adaptive_lm_wiki103.v2.tar.bz2)
|
| 9 |
+
`transformer_lm.wmt19.en` | English LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.en.tar.gz)
|
| 10 |
+
`transformer_lm.wmt19.de` | German LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.de.tar.gz)
|
| 11 |
+
`transformer_lm.wmt19.ru` | Russian LM <br> ([Ng et al., 2019](https://arxiv.org/abs/1907.06616)) | [WMT News Crawl](http://data.statmt.org/news-crawl/) | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/lm/wmt19.ru.tar.gz)
|
| 12 |
+
|
| 13 |
+
## Example usage
|
| 14 |
+
|
| 15 |
+
We require a few additional Python dependencies for preprocessing:
|
| 16 |
+
```bash
|
| 17 |
+
pip install fastBPE sacremoses
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
To sample from a language model using PyTorch Hub:
|
| 21 |
+
```python
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
# List available models
|
| 25 |
+
torch.hub.list('pytorch/fairseq') # [..., 'transformer_lm.wmt19.en', ...]
|
| 26 |
+
|
| 27 |
+
# Load an English LM trained on WMT'19 News Crawl data
|
| 28 |
+
en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
|
| 29 |
+
en_lm.eval() # disable dropout
|
| 30 |
+
|
| 31 |
+
# Move model to GPU
|
| 32 |
+
en_lm.cuda()
|
| 33 |
+
|
| 34 |
+
# Sample from the language model
|
| 35 |
+
en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
|
| 36 |
+
# "Barack Obama is coming to Sydney and New Zealand (...)"
|
| 37 |
+
|
| 38 |
+
# Compute perplexity for a sequence
|
| 39 |
+
en_lm.score('Barack Obama is coming to Sydney and New Zealand')['positional_scores'].mean().neg().exp()
|
| 40 |
+
# tensor(15.1474)
|
| 41 |
+
|
| 42 |
+
# The same interface can be used with custom models as well
|
| 43 |
+
from fairseq.models.transformer_lm import TransformerLanguageModel
|
| 44 |
+
custom_lm = TransformerLanguageModel.from_pretrained('/path/to/model/dir', 'checkpoint100.pt', tokenizer='moses', bpe='fastbpe')
|
| 45 |
+
custom_lm.sample('Barack Obama', beam=5)
|
| 46 |
+
# "Barack Obama (...)"
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
## Training a transformer language model with the CLI tools
|
| 50 |
+
|
| 51 |
+
### 1) Preprocess the data
|
| 52 |
+
|
| 53 |
+
First download and prepare the [WikiText-103 dataset](https://www.salesforce.com/products/einstein/ai-research/the-wikitext-dependency-language-modeling-dataset/):
|
| 54 |
+
```bash
|
| 55 |
+
cd examples/language_model/
|
| 56 |
+
bash prepare-wikitext-103.sh
|
| 57 |
+
cd ../..
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
Next preprocess/binarize the data:
|
| 61 |
+
```bash
|
| 62 |
+
TEXT=examples/language_model/wikitext-103
|
| 63 |
+
fairseq-preprocess \
|
| 64 |
+
--only-source \
|
| 65 |
+
--trainpref $TEXT/wiki.train.tokens \
|
| 66 |
+
--validpref $TEXT/wiki.valid.tokens \
|
| 67 |
+
--testpref $TEXT/wiki.test.tokens \
|
| 68 |
+
--destdir data-bin/wikitext-103 \
|
| 69 |
+
--workers 20
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
### 2) Train a language model
|
| 73 |
+
|
| 74 |
+
Next we'll train a basic transformer language model on wikitext-103. For more
|
| 75 |
+
advanced usage, see the [adaptive inputs README](README.adaptive_inputs.md).
|
| 76 |
+
|
| 77 |
+
To train a basic LM (assumes 2 GPUs):
|
| 78 |
+
```
|
| 79 |
+
$ fairseq-train --task language_modeling \
|
| 80 |
+
data-bin/wikitext-103 \
|
| 81 |
+
--save-dir checkpoints/transformer_wikitext-103 \
|
| 82 |
+
--arch transformer_lm --share-decoder-input-output-embed \
|
| 83 |
+
--dropout 0.1 \
|
| 84 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' --weight-decay 0.01 --clip-norm 0.0 \
|
| 85 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
|
| 86 |
+
--tokens-per-sample 512 --sample-break-mode none \
|
| 87 |
+
--max-tokens 2048 --update-freq 16 \
|
| 88 |
+
--fp16 \
|
| 89 |
+
--max-update 50000
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
If you run out of memory, try reducing `--max-tokens` (max number of tokens per
|
| 93 |
+
batch) or `--tokens-per-sample` (max sequence length). You can also adjust
|
| 94 |
+
`--update-freq` to accumulate gradients and simulate training on a different
|
| 95 |
+
number of GPUs.
|
| 96 |
+
|
| 97 |
+
### 3) Evaluate
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
fairseq-eval-lm data-bin/wikitext-103 \
|
| 101 |
+
--path checkpoints/transformer_wiki103/checkpoint_best.pt \
|
| 102 |
+
--batch-size 2 \
|
| 103 |
+
--tokens-per-sample 512 \
|
| 104 |
+
--context-window 400
|
| 105 |
+
# | Evaluated 245569 tokens in 56.1s (4379.02 tokens/s)
|
| 106 |
+
# | Loss: 3.4164, Perplexity: 30.46
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
*Note:* The `--context-window` option controls how much context is provided to
|
| 110 |
+
each token when computing perplexity. When the window size is 0, the dataset is
|
| 111 |
+
chunked into segments of length 512 and perplexity is computed over each segment
|
| 112 |
+
normally. However, this results in worse (higher) perplexity since tokens that
|
| 113 |
+
appear earlier in each segment have less conditioning. When the maximum window
|
| 114 |
+
size is used (511 in this case), then we compute perplexity for each token
|
| 115 |
+
fully conditioned on 511 tokens of context. This slows down evaluation
|
| 116 |
+
significantly, since we must run a separate forward pass for every token in the
|
| 117 |
+
dataset, but results in better (lower) perplexity.
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
## Convolutional language models
|
| 121 |
+
|
| 122 |
+
Please see the [convolutional LM README](README.conv.md) for instructions on
|
| 123 |
+
training convolutional language models.
|
fairseq-0.10.2/examples/language_model/prepare-wikitext-103.sh
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# Adapted from https://github.com/facebookresearch/MIXER/blob/master/prepareData.sh
|
| 3 |
+
|
| 4 |
+
URLS=(
|
| 5 |
+
"https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-103-v1.zip"
|
| 6 |
+
)
|
| 7 |
+
FILES=(
|
| 8 |
+
"wikitext-103-v1.zip"
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
for ((i=0;i<${#URLS[@]};++i)); do
|
| 12 |
+
file=${FILES[i]}
|
| 13 |
+
if [ -f $file ]; then
|
| 14 |
+
echo "$file already exists, skipping download"
|
| 15 |
+
else
|
| 16 |
+
url=${URLS[i]}
|
| 17 |
+
wget "$url"
|
| 18 |
+
if [ -f $file ]; then
|
| 19 |
+
echo "$url successfully downloaded."
|
| 20 |
+
else
|
| 21 |
+
echo "$url not successfully downloaded."
|
| 22 |
+
exit -1
|
| 23 |
+
fi
|
| 24 |
+
if [ ${file: -4} == ".tgz" ]; then
|
| 25 |
+
tar zxvf $file
|
| 26 |
+
elif [ ${file: -4} == ".tar" ]; then
|
| 27 |
+
tar xvf $file
|
| 28 |
+
elif [ ${file: -4} == ".zip" ]; then
|
| 29 |
+
unzip $file
|
| 30 |
+
fi
|
| 31 |
+
fi
|
| 32 |
+
done
|
| 33 |
+
cd ..
|
fairseq-0.10.2/examples/linformer/README.md
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Linformer: Self-Attention with Linear Complexity (Wang et al., 2020)
|
| 2 |
+
|
| 3 |
+
This example contains code to train Linformer models as described in our paper
|
| 4 |
+
[Linformer: Self-Attention with Linear Complexity](https://arxiv.org/abs/2006.04768).
|
| 5 |
+
|
| 6 |
+
## Training a new Linformer RoBERTa model
|
| 7 |
+
|
| 8 |
+
You can mostly follow the [RoBERTa pretraining README](/examples/roberta/README.pretraining.md),
|
| 9 |
+
updating your training command with `--user-dir examples/linformer/linformer_src --arch linformer_roberta_base`.
|
| 10 |
+
|
| 11 |
+
## Citation
|
| 12 |
+
|
| 13 |
+
If you use our work, please cite:
|
| 14 |
+
|
| 15 |
+
```bibtex
|
| 16 |
+
@article{wang2020linformer,
|
| 17 |
+
title={Linformer: Self-Attention with Linear Complexity},
|
| 18 |
+
author={Wang, Sinong and Li, Belinda and Khabsa, Madian and Fang, Han and Ma, Hao},
|
| 19 |
+
journal={arXiv preprint arXiv:2006.04768},
|
| 20 |
+
year={2020}
|
| 21 |
+
}
|
| 22 |
+
```
|
fairseq-0.10.2/examples/linformer/linformer_src/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .models import linformer_roberta # noqa
|
fairseq-0.10.2/examples/linformer/linformer_src/models/__init__.py
ADDED
|
File without changes
|
fairseq-0.10.2/examples/linformer/linformer_src/models/linformer_roberta.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
"""
|
| 6 |
+
Linformer: Self-Attention with Linear Complexity
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from fairseq.models import register_model, register_model_architecture
|
| 12 |
+
from fairseq.models.roberta import RobertaEncoder, RobertaModel
|
| 13 |
+
|
| 14 |
+
from ..modules.linformer_sentence_encoder import LinformerSentenceEncoder
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@register_model("linformer_roberta")
|
| 21 |
+
class LinformerModel(RobertaModel):
|
| 22 |
+
@staticmethod
|
| 23 |
+
def add_args(parser):
|
| 24 |
+
RobertaModel.add_args(parser)
|
| 25 |
+
|
| 26 |
+
# add args for Linformer
|
| 27 |
+
parser.add_argument(
|
| 28 |
+
"--compressed", type=int, help="compressed ratio of sequence length"
|
| 29 |
+
)
|
| 30 |
+
parser.add_argument(
|
| 31 |
+
"--shared-kv-compressed",
|
| 32 |
+
type=int,
|
| 33 |
+
help="share compressed matrix between k and v, in each layer",
|
| 34 |
+
)
|
| 35 |
+
parser.add_argument(
|
| 36 |
+
"--shared-layer-kv-compressed",
|
| 37 |
+
type=int,
|
| 38 |
+
help="share compressed matrix between k and v and across all layers",
|
| 39 |
+
)
|
| 40 |
+
parser.add_argument(
|
| 41 |
+
"--freeze-compress",
|
| 42 |
+
type=int,
|
| 43 |
+
help="freeze the parameters in compressed layer",
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
@classmethod
|
| 47 |
+
def build_model(cls, args, task):
|
| 48 |
+
"""Build a new model instance."""
|
| 49 |
+
|
| 50 |
+
# make sure all arguments are present
|
| 51 |
+
base_architecture(args)
|
| 52 |
+
|
| 53 |
+
if not hasattr(args, "max_positions"):
|
| 54 |
+
args.max_positions = args.tokens_per_sample
|
| 55 |
+
|
| 56 |
+
encoder = LinformerEncoder(args, task.source_dictionary)
|
| 57 |
+
return cls(args, encoder)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class LinformerEncoder(RobertaEncoder):
|
| 61 |
+
"""Linformer encoder."""
|
| 62 |
+
|
| 63 |
+
def __init__(self, args, dictionary):
|
| 64 |
+
super().__init__(args, dictionary)
|
| 65 |
+
|
| 66 |
+
self.sentence_encoder = LinformerSentenceEncoder(
|
| 67 |
+
padding_idx=dictionary.pad(),
|
| 68 |
+
vocab_size=len(dictionary),
|
| 69 |
+
num_encoder_layers=args.encoder_layers,
|
| 70 |
+
embedding_dim=args.encoder_embed_dim,
|
| 71 |
+
ffn_embedding_dim=args.encoder_ffn_embed_dim,
|
| 72 |
+
num_attention_heads=args.encoder_attention_heads,
|
| 73 |
+
dropout=args.dropout,
|
| 74 |
+
attention_dropout=args.attention_dropout,
|
| 75 |
+
activation_dropout=args.activation_dropout,
|
| 76 |
+
layerdrop=args.encoder_layerdrop,
|
| 77 |
+
max_seq_len=args.max_positions,
|
| 78 |
+
num_segments=0,
|
| 79 |
+
encoder_normalize_before=True,
|
| 80 |
+
apply_bert_init=True,
|
| 81 |
+
activation_fn=args.activation_fn,
|
| 82 |
+
q_noise=args.quant_noise_pq,
|
| 83 |
+
qn_block_size=args.quant_noise_pq_block_size,
|
| 84 |
+
compressed=args.compressed,
|
| 85 |
+
shared_kv_compressed=args.shared_kv_compressed,
|
| 86 |
+
shared_layer_kv_compressed=args.shared_layer_kv_compressed,
|
| 87 |
+
freeze_compress=args.freeze_compress,
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@register_model_architecture("linformer_roberta", "linformer_roberta")
|
| 92 |
+
def base_architecture(args):
|
| 93 |
+
args.encoder_layers = getattr(args, "encoder_layers", 12)
|
| 94 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
|
| 95 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 3072)
|
| 96 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
|
| 97 |
+
|
| 98 |
+
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
| 99 |
+
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
| 100 |
+
|
| 101 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
| 102 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
| 103 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
| 104 |
+
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
| 105 |
+
args.encoder_layers_to_keep = getattr(args, "encoder_layers_to_keep", None)
|
| 106 |
+
args.encoder_layerdrop = getattr(args, "encoder_layerdrop", 0.0)
|
| 107 |
+
args.compressed = getattr(args, "compressed", 4)
|
| 108 |
+
args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
|
| 109 |
+
args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
|
| 110 |
+
args.freeze_compress = getattr(args, "freeze_compress", 0)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@register_model_architecture("linformer_roberta", "linformer_roberta_base")
|
| 114 |
+
def linformer_roberta_base_architecture(args):
|
| 115 |
+
base_architecture(args)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
@register_model_architecture("linformer_roberta", "linformer_roberta_large")
|
| 119 |
+
def linformer_roberta_large_architecture(args):
|
| 120 |
+
args.encoder_layers = getattr(args, "encoder_layers", 24)
|
| 121 |
+
args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
|
| 122 |
+
args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4096)
|
| 123 |
+
args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
|
| 124 |
+
|
| 125 |
+
args.activation_fn = getattr(args, "activation_fn", "gelu")
|
| 126 |
+
args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
|
| 127 |
+
|
| 128 |
+
args.dropout = getattr(args, "dropout", 0.1)
|
| 129 |
+
args.attention_dropout = getattr(args, "attention_dropout", 0.1)
|
| 130 |
+
args.activation_dropout = getattr(args, "activation_dropout", 0.0)
|
| 131 |
+
args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)
|
| 132 |
+
args.compressed = getattr(args, "compressed", 4)
|
| 133 |
+
args.shared_kv_compressed = getattr(args, "shared_kv_compressed", 0)
|
| 134 |
+
args.shared_layer_kv_compressed = getattr(args, "shared_layer_kv_compressed", 0)
|
fairseq-0.10.2/examples/linformer/linformer_src/modules/__init__.py
ADDED
|
File without changes
|
fairseq-0.10.2/examples/linformer/linformer_src/modules/linformer_sentence_encoder.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from fairseq.modules import TransformerSentenceEncoder
|
| 10 |
+
|
| 11 |
+
from .linformer_sentence_encoder_layer import LinformerSentenceEncoderLayer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LinformerSentenceEncoder(TransformerSentenceEncoder):
|
| 15 |
+
"""
|
| 16 |
+
Implementation for a Bi-directional Linformer based Sentence Encoder used
|
| 17 |
+
in BERT/XLM style pre-trained models.
|
| 18 |
+
|
| 19 |
+
This first computes the token embedding using the token embedding matrix,
|
| 20 |
+
position embeddings (if specified) and segment embeddings
|
| 21 |
+
(if specified). After applying the specified number of
|
| 22 |
+
LinformerEncoderLayers, it outputs all the internal states of the
|
| 23 |
+
encoder as well as the final representation associated with the first
|
| 24 |
+
token (usually CLS token).
|
| 25 |
+
|
| 26 |
+
Input:
|
| 27 |
+
- tokens: B x T matrix representing sentences
|
| 28 |
+
- segment_labels: B x T matrix representing segment label for tokens
|
| 29 |
+
|
| 30 |
+
Output:
|
| 31 |
+
- a tuple of the following:
|
| 32 |
+
- a list of internal model states used to compute the
|
| 33 |
+
predictions where each tensor has shape T x B x C
|
| 34 |
+
- sentence representation associated with first input token
|
| 35 |
+
in format B x C.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
padding_idx: int,
|
| 41 |
+
vocab_size: int,
|
| 42 |
+
num_encoder_layers: int = 6,
|
| 43 |
+
embedding_dim: int = 768,
|
| 44 |
+
ffn_embedding_dim: int = 3072,
|
| 45 |
+
num_attention_heads: int = 8,
|
| 46 |
+
dropout: float = 0.1,
|
| 47 |
+
attention_dropout: float = 0.1,
|
| 48 |
+
activation_dropout: float = 0.1,
|
| 49 |
+
layerdrop: float = 0.0,
|
| 50 |
+
max_seq_len: int = 256,
|
| 51 |
+
num_segments: int = 2,
|
| 52 |
+
use_position_embeddings: bool = True,
|
| 53 |
+
offset_positions_by_padding: bool = True,
|
| 54 |
+
encoder_normalize_before: bool = False,
|
| 55 |
+
apply_bert_init: bool = False,
|
| 56 |
+
activation_fn: str = "relu",
|
| 57 |
+
learned_pos_embedding: bool = True,
|
| 58 |
+
embed_scale: float = None,
|
| 59 |
+
freeze_embeddings: bool = False,
|
| 60 |
+
n_trans_layers_to_freeze: int = 0,
|
| 61 |
+
export: bool = False,
|
| 62 |
+
traceable: bool = False,
|
| 63 |
+
q_noise: float = 0.0,
|
| 64 |
+
qn_block_size: int = 8,
|
| 65 |
+
compressed: int = 4,
|
| 66 |
+
shared_kv_compressed: int = 0,
|
| 67 |
+
shared_layer_kv_compressed: int = 0,
|
| 68 |
+
freeze_compress: int = 0,
|
| 69 |
+
) -> None:
|
| 70 |
+
|
| 71 |
+
# Initialize linformer parameters
|
| 72 |
+
self.compressed = compressed
|
| 73 |
+
self.shared_kv_compressed = shared_kv_compressed
|
| 74 |
+
self.shared_layer_kv_compressed = shared_layer_kv_compressed
|
| 75 |
+
self.compress_layer = None
|
| 76 |
+
self.freeze_compress = freeze_compress
|
| 77 |
+
|
| 78 |
+
super().__init__(
|
| 79 |
+
padding_idx=padding_idx,
|
| 80 |
+
vocab_size=vocab_size,
|
| 81 |
+
num_encoder_layers=num_encoder_layers,
|
| 82 |
+
embedding_dim=embedding_dim,
|
| 83 |
+
ffn_embedding_dim=ffn_embedding_dim,
|
| 84 |
+
num_attention_heads=num_attention_heads,
|
| 85 |
+
dropout=dropout,
|
| 86 |
+
attention_dropout=attention_dropout,
|
| 87 |
+
activation_dropout=activation_dropout,
|
| 88 |
+
layerdrop=layerdrop,
|
| 89 |
+
max_seq_len=max_seq_len,
|
| 90 |
+
num_segments=num_segments,
|
| 91 |
+
use_position_embeddings=use_position_embeddings,
|
| 92 |
+
offset_positions_by_padding=offset_positions_by_padding,
|
| 93 |
+
encoder_normalize_before=encoder_normalize_before,
|
| 94 |
+
apply_bert_init=apply_bert_init,
|
| 95 |
+
activation_fn=activation_fn,
|
| 96 |
+
learned_pos_embedding=learned_pos_embedding,
|
| 97 |
+
embed_scale=embed_scale,
|
| 98 |
+
freeze_embeddings=freeze_embeddings,
|
| 99 |
+
n_trans_layers_to_freeze=n_trans_layers_to_freeze,
|
| 100 |
+
export=export,
|
| 101 |
+
traceable=traceable,
|
| 102 |
+
q_noise=q_noise,
|
| 103 |
+
qn_block_size=qn_block_size,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
def build_transformer_sentence_encoder_layer(
|
| 107 |
+
self,
|
| 108 |
+
embedding_dim,
|
| 109 |
+
ffn_embedding_dim,
|
| 110 |
+
num_attention_heads,
|
| 111 |
+
dropout,
|
| 112 |
+
attention_dropout,
|
| 113 |
+
activation_dropout,
|
| 114 |
+
activation_fn,
|
| 115 |
+
export,
|
| 116 |
+
q_noise,
|
| 117 |
+
qn_block_size,
|
| 118 |
+
):
|
| 119 |
+
if self.shared_layer_kv_compressed == 1:
|
| 120 |
+
compress_layer = nn.Linear(
|
| 121 |
+
self.max_seq_len, self.max_seq_len // self.compressed
|
| 122 |
+
)
|
| 123 |
+
# intialize parameters for compressed layer
|
| 124 |
+
nn.init.xavier_uniform_(compress_layer.weight, gain=1 / math.sqrt(2))
|
| 125 |
+
if self.freeze_compress == 1:
|
| 126 |
+
compress_layer.weight.requires_grad = False
|
| 127 |
+
self.compress_layer = compress_layer
|
| 128 |
+
|
| 129 |
+
return LinformerSentenceEncoderLayer(
|
| 130 |
+
embedding_dim=embedding_dim,
|
| 131 |
+
ffn_embedding_dim=ffn_embedding_dim,
|
| 132 |
+
num_attention_heads=num_attention_heads,
|
| 133 |
+
dropout=dropout,
|
| 134 |
+
attention_dropout=attention_dropout,
|
| 135 |
+
activation_dropout=activation_dropout,
|
| 136 |
+
activation_fn=activation_fn,
|
| 137 |
+
export=export,
|
| 138 |
+
q_noise=q_noise,
|
| 139 |
+
qn_block_size=qn_block_size,
|
| 140 |
+
compressed=self.compressed,
|
| 141 |
+
max_seq_len=self.max_seq_len,
|
| 142 |
+
shared_kv_compressed=self.shared_kv_compressed,
|
| 143 |
+
shared_compress_layer=(
|
| 144 |
+
None if self.shared_layer_kv_compressed == 0 else self.compress_layer
|
| 145 |
+
),
|
| 146 |
+
freeze_compress=self.freeze_compress,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
def upgrade_state_dict_named(self, state_dict, name):
|
| 150 |
+
prefix = name + "." if name != "" else ""
|
| 151 |
+
items_to_add = {}
|
| 152 |
+
keys_to_remove = []
|
| 153 |
+
|
| 154 |
+
# update key name for shared layer in new version of code
|
| 155 |
+
for k in state_dict.keys():
|
| 156 |
+
if k.startswith(prefix + "compress_layer"):
|
| 157 |
+
if self.shared_layer_kv_compressed:
|
| 158 |
+
for layer_idx in range(len(self.layers)):
|
| 159 |
+
new_k = prefix + "layers.{0}.shared_compress_layer.{1}".format(
|
| 160 |
+
layer_idx,
|
| 161 |
+
k[len(prefix + "compress_layer.") :],
|
| 162 |
+
)
|
| 163 |
+
items_to_add[new_k] = state_dict[k]
|
| 164 |
+
|
| 165 |
+
for k in keys_to_remove:
|
| 166 |
+
del state_dict[k]
|
| 167 |
+
|
| 168 |
+
for key, value in items_to_add.items():
|
| 169 |
+
state_dict[key] = value
|
fairseq-0.10.2/examples/linformer/linformer_src/modules/linformer_sentence_encoder_layer.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Callable
|
| 7 |
+
|
| 8 |
+
from fairseq.modules import TransformerSentenceEncoderLayer
|
| 9 |
+
|
| 10 |
+
from .multihead_linear_attention import MultiheadLinearAttention
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class LinformerSentenceEncoderLayer(TransformerSentenceEncoderLayer):
|
| 14 |
+
"""
|
| 15 |
+
Implements a Linformer Encoder Layer used in BERT/XLM style pre-trained
|
| 16 |
+
models.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
embedding_dim: int = 768,
|
| 22 |
+
ffn_embedding_dim: int = 3072,
|
| 23 |
+
num_attention_heads: int = 8,
|
| 24 |
+
dropout: float = 0.1,
|
| 25 |
+
attention_dropout: float = 0.1,
|
| 26 |
+
activation_dropout: float = 0.1,
|
| 27 |
+
activation_fn: str = "relu",
|
| 28 |
+
export: bool = False,
|
| 29 |
+
q_noise: float = 0.0,
|
| 30 |
+
qn_block_size: int = 8,
|
| 31 |
+
init_fn: Callable = None,
|
| 32 |
+
compressed: int = 1,
|
| 33 |
+
max_seq_len: int = 256,
|
| 34 |
+
shared_kv_compressed: int = 0,
|
| 35 |
+
shared_compress_layer: any = None,
|
| 36 |
+
freeze_compress: int = 0,
|
| 37 |
+
) -> None:
|
| 38 |
+
|
| 39 |
+
# Initialize linformer parameters
|
| 40 |
+
self.compressed = compressed
|
| 41 |
+
self.max_seq_len = max_seq_len
|
| 42 |
+
self.shared_kv_compressed = shared_kv_compressed
|
| 43 |
+
self.freeze_compress = freeze_compress
|
| 44 |
+
|
| 45 |
+
def init_fn():
|
| 46 |
+
# This needs to be set after nn.Module.__init__ is called
|
| 47 |
+
self.shared_compress_layer = shared_compress_layer
|
| 48 |
+
|
| 49 |
+
super().__init__(
|
| 50 |
+
embedding_dim=embedding_dim,
|
| 51 |
+
ffn_embedding_dim=ffn_embedding_dim,
|
| 52 |
+
num_attention_heads=num_attention_heads,
|
| 53 |
+
dropout=dropout,
|
| 54 |
+
attention_dropout=attention_dropout,
|
| 55 |
+
activation_dropout=activation_dropout,
|
| 56 |
+
activation_fn=activation_fn,
|
| 57 |
+
export=export,
|
| 58 |
+
q_noise=q_noise,
|
| 59 |
+
qn_block_size=qn_block_size,
|
| 60 |
+
init_fn=init_fn,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def build_self_attention(
|
| 64 |
+
self,
|
| 65 |
+
embed_dim,
|
| 66 |
+
num_attention_heads,
|
| 67 |
+
dropout,
|
| 68 |
+
self_attention,
|
| 69 |
+
q_noise,
|
| 70 |
+
qn_block_size,
|
| 71 |
+
):
|
| 72 |
+
return MultiheadLinearAttention(
|
| 73 |
+
embed_dim,
|
| 74 |
+
num_attention_heads,
|
| 75 |
+
dropout=dropout,
|
| 76 |
+
self_attention=True,
|
| 77 |
+
q_noise=q_noise,
|
| 78 |
+
qn_block_size=qn_block_size,
|
| 79 |
+
compressed=self.compressed,
|
| 80 |
+
max_seq_len=self.max_seq_len,
|
| 81 |
+
shared_kv_compressed=self.shared_kv_compressed,
|
| 82 |
+
shared_compress_layer=self.shared_compress_layer,
|
| 83 |
+
freeze_compress=self.freeze_compress,
|
| 84 |
+
)
|
fairseq-0.10.2/examples/linformer/linformer_src/modules/multihead_linear_attention.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from typing import Dict, Optional, Tuple
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from fairseq import utils
|
| 12 |
+
from fairseq.incremental_decoding_utils import with_incremental_state
|
| 13 |
+
from fairseq.modules.quant_noise import quant_noise
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
from torch.nn import Parameter
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@with_incremental_state
|
| 19 |
+
class MultiheadLinearAttention(nn.Module):
|
| 20 |
+
"""Multi-headed linformer attention.
|
| 21 |
+
|
| 22 |
+
Projects the key and values down to the compressed dimension, before computing self-attention.
|
| 23 |
+
|
| 24 |
+
See "Linformer: Self-Attention with Linear Complexity" for more details.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
embed_dim,
|
| 30 |
+
num_heads,
|
| 31 |
+
kdim=None,
|
| 32 |
+
vdim=None,
|
| 33 |
+
dropout=0.0,
|
| 34 |
+
bias=True,
|
| 35 |
+
add_bias_kv=False,
|
| 36 |
+
add_zero_attn=False,
|
| 37 |
+
self_attention=False,
|
| 38 |
+
encoder_decoder_attention=False,
|
| 39 |
+
q_noise=0.0,
|
| 40 |
+
qn_block_size=8,
|
| 41 |
+
compressed=1,
|
| 42 |
+
max_seq_len=256,
|
| 43 |
+
shared_kv_compressed=0,
|
| 44 |
+
shared_compress_layer=None,
|
| 45 |
+
freeze_compress=0,
|
| 46 |
+
):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.embed_dim = embed_dim
|
| 49 |
+
self.kdim = kdim if kdim is not None else embed_dim
|
| 50 |
+
self.vdim = vdim if vdim is not None else embed_dim
|
| 51 |
+
self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim
|
| 52 |
+
|
| 53 |
+
self.num_heads = num_heads
|
| 54 |
+
self.dropout = dropout
|
| 55 |
+
self.head_dim = embed_dim // num_heads
|
| 56 |
+
assert (
|
| 57 |
+
self.head_dim * num_heads == self.embed_dim
|
| 58 |
+
), "embed_dim must be divisible by num_heads"
|
| 59 |
+
self.scaling = self.head_dim ** -0.5
|
| 60 |
+
|
| 61 |
+
self.self_attention = self_attention
|
| 62 |
+
self.encoder_decoder_attention = encoder_decoder_attention
|
| 63 |
+
|
| 64 |
+
assert not self.self_attention or self.qkv_same_dim, (
|
| 65 |
+
"Self-attention requires query, key and " "value to be of the same size"
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
self.k_proj = quant_noise(
|
| 69 |
+
nn.Linear(self.kdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 70 |
+
)
|
| 71 |
+
self.v_proj = quant_noise(
|
| 72 |
+
nn.Linear(self.vdim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 73 |
+
)
|
| 74 |
+
self.q_proj = quant_noise(
|
| 75 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# used for compress sequence to subsequence
|
| 79 |
+
if shared_compress_layer is None:
|
| 80 |
+
self.compress_seq_len = max_seq_len // compressed
|
| 81 |
+
self.compress_k = nn.Linear(max_seq_len, self.compress_seq_len, bias=False)
|
| 82 |
+
if shared_kv_compressed == 0:
|
| 83 |
+
self.compress_v = nn.Linear(
|
| 84 |
+
max_seq_len, self.compress_seq_len, bias=False
|
| 85 |
+
)
|
| 86 |
+
self.layerwise_sharing = False
|
| 87 |
+
else:
|
| 88 |
+
self.compress_k = shared_compress_layer
|
| 89 |
+
if shared_kv_compressed == 0:
|
| 90 |
+
self.compress_v = shared_compress_layer
|
| 91 |
+
self.layerwise_sharing = True
|
| 92 |
+
self.shared_kv_compressed = shared_kv_compressed
|
| 93 |
+
|
| 94 |
+
self.out_proj = quant_noise(
|
| 95 |
+
nn.Linear(embed_dim, embed_dim, bias=bias), q_noise, qn_block_size
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
if add_bias_kv:
|
| 99 |
+
self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 100 |
+
self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
|
| 101 |
+
else:
|
| 102 |
+
self.bias_k = self.bias_v = None
|
| 103 |
+
|
| 104 |
+
self.add_zero_attn = add_zero_attn
|
| 105 |
+
|
| 106 |
+
self.reset_parameters()
|
| 107 |
+
|
| 108 |
+
if freeze_compress == 1:
|
| 109 |
+
self.compress_k.weight.requires_grad = False
|
| 110 |
+
if shared_kv_compressed == 0:
|
| 111 |
+
self.compress_v.weight.requires_grad = False
|
| 112 |
+
|
| 113 |
+
self.onnx_trace = False
|
| 114 |
+
self.tpu = False
|
| 115 |
+
|
| 116 |
+
def prepare_for_onnx_export_(self):
|
| 117 |
+
self.onnx_trace = True
|
| 118 |
+
|
| 119 |
+
def prepare_for_tpu_(self, **kwargs):
|
| 120 |
+
self.tpu = True
|
| 121 |
+
|
| 122 |
+
def reset_parameters(self):
|
| 123 |
+
if self.qkv_same_dim:
|
| 124 |
+
# Empirically observed the convergence to be much better with
|
| 125 |
+
# the scaled initialization
|
| 126 |
+
nn.init.xavier_uniform_(self.k_proj.weight, gain=1 / math.sqrt(2))
|
| 127 |
+
nn.init.xavier_uniform_(self.v_proj.weight, gain=1 / math.sqrt(2))
|
| 128 |
+
nn.init.xavier_uniform_(self.q_proj.weight, gain=1 / math.sqrt(2))
|
| 129 |
+
if (
|
| 130 |
+
not self.layerwise_sharing
|
| 131 |
+
): # otherwise, we already initialize the parameters
|
| 132 |
+
nn.init.xavier_uniform_(self.compress_k.weight, gain=1 / math.sqrt(2))
|
| 133 |
+
if self.shared_kv_compressed == 0:
|
| 134 |
+
nn.init.xavier_uniform_(
|
| 135 |
+
self.compress_v.weight, gain=1 / math.sqrt(2)
|
| 136 |
+
)
|
| 137 |
+
else:
|
| 138 |
+
nn.init.xavier_uniform_(self.k_proj.weight)
|
| 139 |
+
nn.init.xavier_uniform_(self.v_proj.weight)
|
| 140 |
+
nn.init.xavier_uniform_(self.q_proj.weight)
|
| 141 |
+
if (
|
| 142 |
+
not self.layerwise_sharing
|
| 143 |
+
): # otherwise, we already initialize the parameters
|
| 144 |
+
nn.init.xavier_uniform_(self.compress_k.weight)
|
| 145 |
+
if self.shared_kv_compressed == 0:
|
| 146 |
+
nn.init.xavier_uniform_(self.compress_v.weight)
|
| 147 |
+
|
| 148 |
+
nn.init.xavier_uniform_(self.out_proj.weight)
|
| 149 |
+
if self.out_proj.bias is not None:
|
| 150 |
+
nn.init.constant_(self.out_proj.bias, 0.0)
|
| 151 |
+
if self.bias_k is not None:
|
| 152 |
+
nn.init.xavier_normal_(self.bias_k)
|
| 153 |
+
if self.bias_v is not None:
|
| 154 |
+
nn.init.xavier_normal_(self.bias_v)
|
| 155 |
+
|
| 156 |
+
def forward(
    self,
    query,
    key: Optional[Tensor],
    value: Optional[Tensor],
    key_padding_mask: Optional[Tensor] = None,
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
    need_weights: bool = True,
    static_kv: bool = False,
    attn_mask: Optional[Tensor] = None,
    before_softmax: bool = False,
    need_head_weights: bool = False,
) -> Tuple[Tensor, Optional[Tensor]]:
    """Linformer multi-head attention.  Input shape: Time x Batch x Channel.

    In the self-attention case, the time (length) dimension of the keys
    and values is first projected down with ``compress_k`` (and
    ``compress_v`` unless ``shared_kv_compressed == 1``) before the usual
    scaled-dot-product attention.

    Args:
        key_padding_mask (ByteTensor, optional): mask to exclude
            keys that are pads, of shape `(batch, src_len)`, where
            padding elements are indicated by 1s.
        need_weights (bool, optional): return the attention weights,
            averaged over heads (default: True).
        static_kv (bool, optional): cached key/value are static across
            decoding steps (encoder-decoder attention only).
        attn_mask (ByteTensor, optional): typically used to
            implement causal attention, where the mask prevents the
            attention from looking forward in time (default: None).
        before_softmax (bool, optional): return the raw attention
            weights and values before the attention softmax.
        need_head_weights (bool, optional): return the attention
            weights for each head. Implies *need_weights*. Default:
            return the average attention weights over all heads.

    Returns:
        Tuple of the attention output `(tgt_len, bsz, embed_dim)` and the
        (optional) attention weights.
    """
    if need_head_weights:
        need_weights = True

    tgt_len, bsz, embed_dim = query.size()
    assert embed_dim == self.embed_dim
    assert list(query.size()) == [tgt_len, bsz, embed_dim]

    if incremental_state is not None:
        saved_state = self._get_input_buffer(incremental_state)
        if saved_state is not None and "prev_key" in saved_state:
            # previous time steps are cached - no need to recompute
            # key and value if they are static
            if static_kv:
                assert self.encoder_decoder_attention and not self.self_attention
                key = value = None
    else:
        saved_state = None

    if self.self_attention:
        q = self.q_proj(query)

        # Linformer compression: project the length dimension T down to a
        # fixed compressed length before computing keys.  Only the first
        # tgt_len columns of the compression weight are used so shorter
        # sequences still work.
        k_input = query.permute(1, 2, 0).contiguous()  # B * C * T
        k_input = (
            F.linear(k_input, self.compress_k.weight[:, 0:tgt_len])
            .permute(2, 0, 1)
            .contiguous()
        )
        k = self.k_proj(k_input)

        v_input = query.permute(1, 2, 0).contiguous()  # B * C * T
        if self.shared_kv_compressed == 0:
            v_input = (
                F.linear(v_input, self.compress_v.weight[:, 0:tgt_len])
                .permute(2, 0, 1)
                .contiguous()
            )
        if self.shared_kv_compressed == 1:  # use shared kv compressed linear layer
            v_input = (
                F.linear(v_input, self.compress_k.weight[:, 0:tgt_len])
                .permute(2, 0, 1)
                .contiguous()
            )
        v = self.v_proj(v_input)
    elif self.encoder_decoder_attention:
        # encoder-decoder attention: queries come from the decoder, keys
        # and values from the encoder output
        q = self.q_proj(query)
        if key is None:
            assert value is None
            k = v = None
        else:
            k = self.k_proj(key)
            v = self.v_proj(key)

    else:
        assert key is not None and value is not None
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
    q *= self.scaling

    if self.bias_k is not None:
        assert self.bias_v is not None
        # append the learned bias key/value as one extra timestep and
        # extend the masks accordingly
        k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
        v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )
        if key_padding_mask is not None:
            key_padding_mask = torch.cat(
                [
                    key_padding_mask,
                    key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                ],
                dim=1,
            )

    # fold heads into the batch dimension: (bsz*num_heads, len, head_dim)
    q = (
        q.contiguous()
        .view(tgt_len, bsz * self.num_heads, self.head_dim)
        .transpose(0, 1)
    )
    if k is not None:
        k = (
            k.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
    if v is not None:
        v = (
            v.contiguous()
            .view(-1, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )

    if saved_state is not None:
        # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
        if "prev_key" in saved_state:
            _prev_key = saved_state["prev_key"]
            assert _prev_key is not None
            prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                k = prev_key
            else:
                assert k is not None
                k = torch.cat([prev_key, k], dim=1)
        if "prev_value" in saved_state:
            _prev_value = saved_state["prev_value"]
            assert _prev_value is not None
            prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
            if static_kv:
                v = prev_value
            else:
                assert v is not None
                v = torch.cat([prev_value, v], dim=1)
        prev_key_padding_mask: Optional[Tensor] = None
        if "prev_key_padding_mask" in saved_state:
            prev_key_padding_mask = saved_state["prev_key_padding_mask"]
        assert k is not None and v is not None
        key_padding_mask = MultiheadLinearAttention._append_prev_key_padding_mask(
            key_padding_mask=key_padding_mask,
            prev_key_padding_mask=prev_key_padding_mask,
            batch_size=bsz,
            src_len=k.size(1),
            static_kv=static_kv,
        )

        saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
        saved_state["prev_key_padding_mask"] = key_padding_mask
        # In this branch incremental_state is never None
        assert incremental_state is not None
        incremental_state = self._set_input_buffer(incremental_state, saved_state)
    assert k is not None
    src_len = k.size(1)

    if self.add_zero_attn:
        # append one all-zero key/value timestep and extend the attn mask
        assert v is not None
        src_len += 1
        k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
        v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
        if attn_mask is not None:
            attn_mask = torch.cat(
                [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
            )

    attn_weights = torch.bmm(q, k.transpose(1, 2))
    # identity hook by default; subclasses may sparsify (called unbound)
    attn_weights = MultiheadLinearAttention.apply_sparse_mask(
        attn_weights, tgt_len, src_len, bsz
    )

    assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

    if attn_mask is not None:
        attn_mask = attn_mask.unsqueeze(0)
        if self.onnx_trace:
            # ONNX export cannot rely on implicit broadcasting here
            attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
        attn_weights += attn_mask

    if before_softmax:
        return attn_weights, v

    attn_weights_float = utils.softmax(
        attn_weights, dim=-1, onnx_trace=self.onnx_trace
    )
    attn_weights = attn_weights_float.type_as(attn_weights)
    attn_probs = F.dropout(
        attn_weights,
        p=self.dropout,
        training=self.training,
    )
    assert v is not None
    attn = torch.bmm(attn_probs, v)
    assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
    if self.onnx_trace and attn.size(1) == 1:
        # when ONNX tracing a single decoder step (sequence length == 1)
        # the transpose is a no-op copy before view, thus unnecessary
        attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
    else:
        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
    attn = self.out_proj(attn)
    attn_weights: Optional[Tensor] = None
    if need_weights:
        attn_weights = attn_weights_float.view(
            bsz, self.num_heads, tgt_len, src_len
        ).transpose(1, 0)
        if not need_head_weights:
            # average attention weights over heads
            attn_weights = attn_weights.mean(dim=0)

    return attn, attn_weights
|
| 377 |
+
|
| 378 |
+
@staticmethod
|
| 379 |
+
def _append_prev_key_padding_mask(
|
| 380 |
+
key_padding_mask: Optional[Tensor],
|
| 381 |
+
prev_key_padding_mask: Optional[Tensor],
|
| 382 |
+
batch_size: int,
|
| 383 |
+
src_len: int,
|
| 384 |
+
static_kv: bool,
|
| 385 |
+
) -> Optional[Tensor]:
|
| 386 |
+
# saved key padding masks have shape (bsz, seq_len)
|
| 387 |
+
if prev_key_padding_mask is not None and static_kv:
|
| 388 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 389 |
+
elif prev_key_padding_mask is not None and key_padding_mask is not None:
|
| 390 |
+
new_key_padding_mask = torch.cat(
|
| 391 |
+
[prev_key_padding_mask.float(), key_padding_mask.float()], dim=1
|
| 392 |
+
)
|
| 393 |
+
# During incremental decoding, as the padding token enters and
|
| 394 |
+
# leaves the frame, there will be a time when prev or current
|
| 395 |
+
# is None
|
| 396 |
+
elif prev_key_padding_mask is not None:
|
| 397 |
+
filler = torch.zeros(
|
| 398 |
+
(batch_size, src_len - prev_key_padding_mask.size(1)),
|
| 399 |
+
device=prev_key_padding_mask.device,
|
| 400 |
+
)
|
| 401 |
+
new_key_padding_mask = torch.cat(
|
| 402 |
+
[prev_key_padding_mask.float(), filler.float()], dim=1
|
| 403 |
+
)
|
| 404 |
+
elif key_padding_mask is not None:
|
| 405 |
+
filler = torch.zeros(
|
| 406 |
+
(batch_size, src_len - key_padding_mask.size(1)),
|
| 407 |
+
device=key_padding_mask.device,
|
| 408 |
+
)
|
| 409 |
+
new_key_padding_mask = torch.cat(
|
| 410 |
+
[filler.float(), key_padding_mask.float()], dim=1
|
| 411 |
+
)
|
| 412 |
+
else:
|
| 413 |
+
new_key_padding_mask = prev_key_padding_mask
|
| 414 |
+
return new_key_padding_mask
|
| 415 |
+
|
| 416 |
+
@torch.jit.export
def reorder_incremental_state(
    self,
    incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
    new_order: Tensor,
):
    """Reorder buffered internal state (for incremental generation).

    Selects the batch entries of every cached tensor according to
    *new_order*.  For encoder-decoder attention, if a cached tensor's
    batch dimension already matches ``new_order``'s size, reordering is
    skipped entirely for the buffer.
    """
    buffer = self._get_input_buffer(incremental_state)
    if buffer is not None:
        for key in buffer.keys():
            cached = buffer[key]
            if cached is not None:
                if self.encoder_decoder_attention and cached.size(
                    0
                ) == new_order.size(0):
                    break
                buffer[key] = cached.index_select(0, new_order)
        incremental_state = self._set_input_buffer(incremental_state, buffer)
    return incremental_state
|
| 435 |
+
|
| 436 |
+
def _get_input_buffer(
|
| 437 |
+
self, incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]]
|
| 438 |
+
) -> Dict[str, Optional[Tensor]]:
|
| 439 |
+
result = self.get_incremental_state(incremental_state, "attn_state")
|
| 440 |
+
if result is not None:
|
| 441 |
+
return result
|
| 442 |
+
else:
|
| 443 |
+
empty_result: Dict[str, Optional[Tensor]] = {}
|
| 444 |
+
return empty_result
|
| 445 |
+
|
| 446 |
+
def _set_input_buffer(
|
| 447 |
+
self,
|
| 448 |
+
incremental_state: Dict[str, Dict[str, Optional[Tensor]]],
|
| 449 |
+
buffer: Dict[str, Optional[Tensor]],
|
| 450 |
+
):
|
| 451 |
+
return self.set_incremental_state(incremental_state, "attn_state", buffer)
|
| 452 |
+
|
| 453 |
+
def apply_sparse_mask(attn_weights, tgt_len: int, src_len: int, bsz: int):
    """Hook for subclasses to sparsify attention weights; identity here.

    NOTE: defined without ``self``/``@staticmethod`` and invoked unbound
    as ``MultiheadLinearAttention.apply_sparse_mask(...)`` in ``forward``.
    """
    return attn_weights
|
| 455 |
+
|
| 456 |
+
def upgrade_state_dict_named(self, state_dict, name):
    """Split legacy combined in-projection entries in *state_dict*.

    Older checkpoints stored the q/k/v projections as a single
    concatenated ``in_proj_weight`` (and optionally ``in_proj_bias``);
    the current module uses separate ``q_proj``/``k_proj``/``v_proj``
    parameters.  This rewrites *state_dict* in place.

    Args:
        state_dict: checkpoint state dict to upgrade (mutated in place).
        name: module path prefix inside the state dict ("" for top level).
    """
    prefix = name + "." if name != "" else ""
    items_to_add = {}
    keys_to_remove = []
    for k in state_dict.keys():
        if k.endswith(prefix + "in_proj_weight"):
            # in_proj_weight used to be q + k + v with same dimensions
            dim = int(state_dict[k].shape[0] / 3)
            items_to_add[prefix + "q_proj.weight"] = state_dict[k][:dim]
            items_to_add[prefix + "k_proj.weight"] = state_dict[k][dim : 2 * dim]
            items_to_add[prefix + "v_proj.weight"] = state_dict[k][2 * dim :]

            keys_to_remove.append(k)

            k_bias = prefix + "in_proj_bias"
            if k_bias in state_dict.keys():
                # Fix: derive the split point from the bias tensor itself
                # rather than from the weight's row count — the old code
                # only worked because both happen to equal 3 * embed_dim.
                dim = int(state_dict[k_bias].shape[0] / 3)
                items_to_add[prefix + "q_proj.bias"] = state_dict[k_bias][:dim]
                items_to_add[prefix + "k_proj.bias"] = state_dict[k_bias][
                    dim : 2 * dim
                ]
                items_to_add[prefix + "v_proj.bias"] = state_dict[k_bias][2 * dim :]

                keys_to_remove.append(prefix + "in_proj_bias")

    for k in keys_to_remove:
        del state_dict[k]

    for key, value in items_to_add.items():
        state_dict[key] = value
|
fairseq-0.10.2/examples/multilingual/README.md
ADDED
|
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Multilingual Translation
|
| 2 |
+
|
| 3 |
+
[[Multilingual Translation with Extensible Multilingual Pretraining and Finetuning, https://arxiv.org/abs/2008.00401]](https://arxiv.org/abs/2008.00401)
|
| 4 |
+
|
| 5 |
+
## Introduction
|
| 6 |
+
|
| 7 |
+
This work is for training multilingual translation models with multiple bitext datasets. This multilingual translation framework supports (see [[training section]](#Training) and [[finetuning section]](#Finetuning) for examples)
|
| 8 |
+
|
| 9 |
+
* temperature based sampling over unbalancing datasets of different translation directions
|
| 10 |
+
- --sampling-method with
|
| 11 |
+
choices=['uniform', 'temperature', 'concat']
|
| 12 |
+
- --sampling-temperature
|
| 13 |
+
* configurable to automatically add source and/or target language tokens to source/target sentences using data which are prepared in the same way as bilingual training
|
| 14 |
+
- --encoder-langtok with choices=['src', 'tgt', None] to specify whether to add source or target language tokens to the source sentences
|
| 15 |
+
- --decoder-langtok (binary option) to specify whether to add target language tokens to the target sentences or not
|
| 16 |
+
* finetuning mBART pretrained models for multilingual translation
|
| 17 |
+
- --finetune-from-model to specify the path from which to load the pretrained model
|
| 18 |
+
|
| 19 |
+
## Preprocessing data
|
| 20 |
+
Multilingual training requires a joint BPE vocab. Please follow [mBART's preprocessing steps](https://github.com/pytorch/fairseq/tree/master/examples/mbart#bpe-data) to reuse our pretrained sentence-piece model.
|
| 21 |
+
|
| 22 |
+
You can also train a joint BPE model on your own dataset and then follow the steps in [[link]](https://github.com/pytorch/fairseq/tree/master/examples/translation#multilingual-translation).
|
| 23 |
+
|
| 24 |
+
## Training
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
lang_pairs=<language pairs to be trained, e.g. "en-cs,cs-en">
|
| 29 |
+
path_2_data=<set to data path>
|
| 30 |
+
lang_list=<a file which contains a list of languages separated by new lines>
|
| 31 |
+
|
| 32 |
+
fairseq-train $path_2_data \
|
| 33 |
+
--encoder-normalize-before --decoder-normalize-before \
|
| 34 |
+
--arch transformer --layernorm-embedding \
|
| 35 |
+
--task translation_multi_simple_epoch \
|
| 36 |
+
--sampling-method "temperature" \
|
| 37 |
+
--sampling-temperature 1.5 \
|
| 38 |
+
--encoder-langtok "src" \
|
| 39 |
+
--decoder-langtok \
|
| 40 |
+
--lang-dict "$lang_list" \
|
| 41 |
+
--lang-pairs "$lang_pairs" \
|
| 42 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
|
| 43 |
+
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
|
| 44 |
+
--lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \
|
| 45 |
+
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
|
| 46 |
+
--max-tokens 1024 --update-freq 2 \
|
| 47 |
+
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
|
| 48 |
+
--seed 222 --log-format simple --log-interval 2
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## Finetuning
|
| 52 |
+
We can also finetune multilingual models from a monolingual pretrained models, e.g. [mMBART](https://github.com/pytorch/fairseq/tree/master/examples/mbart).
|
| 53 |
+
```bash
|
| 54 |
+
lang_pairs=<language pairs to be trained, e.g. "en-cs,cs-en">
|
| 55 |
+
path_2_data=<set to data path>
|
| 56 |
+
lang_list=<a file which contains a list of languages separated by new lines>
|
| 57 |
+
pretrained_model=<path to the pretrained model, e.g. mbart or another trained multilingual model>
|
| 58 |
+
|
| 59 |
+
fairseq-train $path_2_data \
|
| 60 |
+
--finetune-from-model $pretrained_model \
|
| 61 |
+
--encoder-normalize-before --decoder-normalize-before \
|
| 62 |
+
--arch transformer --layernorm-embedding \
|
| 63 |
+
--task translation_multi_simple_epoch \
|
| 64 |
+
--sampling-method "temperature" \
|
| 65 |
+
--sampling-temperature 1.5 \
|
| 66 |
+
--encoder-langtok "src" \
|
| 67 |
+
--decoder-langtok \
|
| 68 |
+
--lang-dict "$lang_list" \
|
| 69 |
+
--lang-pairs "$lang_pairs" \
|
| 70 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
|
| 71 |
+
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
|
| 72 |
+
--lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \
|
| 73 |
+
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
|
| 74 |
+
--max-tokens 1024 --update-freq 2 \
|
| 75 |
+
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
|
| 76 |
+
--seed 222 --log-format simple --log-interval 2
|
| 77 |
+
```
|
| 78 |
+
## Generate
|
| 79 |
+
The following command uses the multilingual task (translation_multi_simple_epoch) to generate translation from $source_lang to $target_lang on the test dataset. During generation, the source language tokens are added to source sentences and the target language tokens are added as the starting token to decode target sentences. Options --lang-dict and --lang-pairs are needed to tell the generation process the ordered list of languages and translation directions that the trained model is aware of; they will need to be consistent with the training.
|
| 80 |
+
|
| 81 |
+
```bash
|
| 82 |
+
model=<multilingual model>
|
| 83 |
+
source_lang=<source language>
|
| 84 |
+
target_lang=<target language>
|
| 85 |
+
|
| 86 |
+
fairseq-generate $path_2_data \
|
| 87 |
+
--path $model \
|
| 88 |
+
--task translation_multi_simple_epoch \
|
| 89 |
+
--gen-subset test \
|
| 90 |
+
--source-lang $source_lang \
|
| 91 |
+
    --target-lang $target_lang \
|
| 92 |
+
    --sacrebleu --remove-bpe 'sentencepiece' \
|
| 93 |
+
--batch-size 32 \
|
| 94 |
+
--encoder-langtok "src" \
|
| 95 |
+
--decoder-langtok \
|
| 96 |
+
--lang-dict "$lang_list" \
|
| 97 |
+
--lang-pairs "$lang_pairs" > ${source_lang}_${target_lang}.txt
|
| 98 |
+
```
|
| 99 |
+
Fairseq will generate the translation into a file ${source_lang}_${target_lang}.txt with sacrebleu at the end.
|
| 100 |
+
|
| 101 |
+
You can also use a customized tokenizer to compare the performance with the literature. For example, you can get a tokenizer [here](https://github.com/rsennrich/wmt16-scripts) and do the following:
|
| 102 |
+
```bash
|
| 103 |
+
TOKENIZER=<path to a customized tokenizer for decoding evaluation>
|
| 104 |
+
TOK_CMD=<"$TOKENIZER $target_lang" or cat for sacrebleu>
|
| 105 |
+
|
| 106 |
+
cat {source_lang}_${target_lang}.txt | grep -P "^H" |sort -V |cut -f 3- |$TOK_CMD > ${source_lang}_${target_lang}.hyp
|
| 107 |
+
cat {source_lang}_${target_lang}.txt | grep -P "^T" |sort -V |cut -f 2- |$TOK_CMD > ${source_lang}_${target_lang}.ref
|
| 108 |
+
sacrebleu -tok 'none' -s 'none' ${source_lang}_${target_lang}.ref < ${source_lang}_${target_lang}.hyp
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
## Citation
|
| 114 |
+
|
| 115 |
+
```bibtex
|
| 116 |
+
@article{tang2020multilingual,
|
| 117 |
+
title={Multilingual Translation with Extensible Multilingual Pretraining and Finetuning},
|
| 118 |
+
author={Yuqing Tang and Chau Tran and Xian Li and Peng-Jen Chen and Naman Goyal and Vishrav Chaudhary and Jiatao Gu and Angela Fan},
|
| 119 |
+
year={2020},
|
| 120 |
+
eprint={2008.00401},
|
| 121 |
+
archivePrefix={arXiv},
|
| 122 |
+
primaryClass={cs.CL}
|
| 123 |
+
}
|
| 124 |
+
```
|
fairseq-0.10.2/examples/multilingual/finetune_multilingual_model.sh
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
path_2_data=$1 # <path to data> which contains binarized data for each directions
|
| 4 |
+
lang_list=$2 # <path to a file which contains a list of languages separated by new lines>
|
| 5 |
+
lang_pairs=$3 #a list language pairs to train multilingual models, e.g. "en-fr,en-cs,fr-en,cs-en"
|
| 6 |
+
# pretrained can be an mBART pretrained model as well
|
| 7 |
+
pretrained_model=$4 #<path to a pretrained model>
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
fairseq-train "$path_2_data" \
|
| 11 |
+
--encoder-normalize-before --decoder-normalize-before \
|
| 12 |
+
--arch transformer --layernorm-embedding \
|
| 13 |
+
--task translation_multi_simple_epoch \
|
| 14 |
+
--finetune-from-model "$pretrained_model" \
|
| 15 |
+
--sampling-method "temperature" \
|
| 16 |
+
--sampling-temperature "1.5" \
|
| 17 |
+
--encoder-langtok "src" \
|
| 18 |
+
--decoder-langtok \
|
| 19 |
+
--lang-dict "$lang_list" \
|
| 20 |
+
--lang-pairs "$lang_pairs" \
|
| 21 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
|
| 22 |
+
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
|
| 23 |
+
--lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \
|
| 24 |
+
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
|
| 25 |
+
--max-tokens 1024 --update-freq 2 \
|
| 26 |
+
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
|
| 27 |
+
--seed 222 --log-format simple --log-interval 2
|
fairseq-0.10.2/examples/multilingual/multilingual_fairseq_gen.sh
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
lang_pairs="en-fr,en-cs,fr-en,cs-en"
|
| 4 |
+
path_2_data=$1 # <path to data>
|
| 5 |
+
lang_list=$2 # <path to a file which contains list of languages separated by new lines>
|
| 6 |
+
model=$3 # <path to a trained model>
|
| 7 |
+
source_lang=cs
|
| 8 |
+
target_lang=en
|
| 9 |
+
|
| 10 |
+
fairseq-generate "$path_2_data" \
|
| 11 |
+
--path "$model" \
|
| 12 |
+
--task translation_multi_simple_epoch \
|
| 13 |
+
--gen-subset test \
|
| 14 |
+
--source-lang "$source_lang" \
|
| 15 |
+
--target-lang "$target_lang" \
|
| 16 |
+
  --sacrebleu --remove-bpe 'sentencepiece' \
|
| 17 |
+
--batch-size 32 \
|
| 18 |
+
--encoder-langtok "src" \
|
| 19 |
+
--decoder-langtok \
|
| 20 |
+
--lang-dict "$lang_list" \
|
| 21 |
+
--lang-pairs "$lang_pairs"
|
fairseq-0.10.2/examples/multilingual/train_multilingual_model.sh
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
path_2_data=$1 # <path to data> which contains binarized data for each directions
|
| 4 |
+
lang_list=$2 # <path to a file which contains a list of languages separated by new lines>
|
| 5 |
+
lang_pairs=$3 #a list language pairs to train multilingual models, e.g. "en-fr,en-cs,fr-en,cs-en"
|
| 6 |
+
|
| 7 |
+
fairseq-train "$path_2_data" \
|
| 8 |
+
--encoder-normalize-before --decoder-normalize-before \
|
| 9 |
+
--arch transformer --layernorm-embedding \
|
| 10 |
+
--task translation_multi_simple_epoch \
|
| 11 |
+
--sampling-method "temperature" \
|
| 12 |
+
--sampling-temperature 1.5 \
|
| 13 |
+
--encoder-langtok "src" \
|
| 14 |
+
--decoder-langtok \
|
| 15 |
+
--lang-dict "$lang_list" \
|
| 16 |
+
--lang-pairs "$lang_pairs" \
|
| 17 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
|
| 18 |
+
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
|
| 19 |
+
--lr-scheduler inverse_sqrt --lr 3e-05 --min-lr -1 --warmup-updates 2500 --max-update 40000 \
|
| 20 |
+
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
|
| 21 |
+
--max-tokens 1024 --update-freq 2 \
|
| 22 |
+
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
|
| 23 |
+
--seed 222 --log-format simple --log-interval 2
|
fairseq-0.10.2/examples/nonautoregressive_translation/README.md
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Non-autoregressive Neural Machine Translation (NAT)
|
| 2 |
+
|
| 3 |
+
This page mainly includes instructions for reproducing results from the following papers
|
| 4 |
+
* [Levenshtein Transformer (Gu et al., 2019)](https://arxiv.org/abs/1905.11006).
|
| 5 |
+
* [Understanding Knowledge Distillation in Non-autoregressive Machine Translation (Zhou et al., 2019)](https://arxiv.org/abs/1911.02727).
|
| 6 |
+
|
| 7 |
+
We also provided our own implementations for several popular non-autoregressive-based models as reference:<br>
|
| 8 |
+
* [Non-Autoregressive Neural Machine Translation (Gu et al., 2017)](https://arxiv.org/abs/1711.02281)<br>
|
| 9 |
+
* [Deterministic Non-Autoregressive Neural Sequence Modeling by Iterative Refinement (Lee et al., 2018)](https://arxiv.org/abs/1802.06901)<br>
|
| 10 |
+
* [Insertion Transformer: Flexible Sequence Generation via Insertion Operations (Stern et al., 2019)](https://arxiv.org/abs/1902.03249)<br>
|
| 11 |
+
* [Mask-Predict: Parallel Decoding of Conditional Masked Language Models (Ghazvininejad et al., 2019)](https://arxiv.org/abs/1904.09324v2)<br>
|
| 12 |
+
* [Fast Structured Decoding for Sequence Models (Sun et al., 2019)](https://arxiv.org/abs/1910.11555)
|
| 13 |
+
|
| 14 |
+
## Dataset
|
| 15 |
+
|
| 16 |
+
First, follow the [instructions to download and preprocess the WMT'14 En-De dataset](../translation#wmt14-english-to-german-convolutional).
|
| 17 |
+
Make sure to learn a joint vocabulary by passing the `--joined-dictionary` option to `fairseq-preprocess`.
|
| 18 |
+
|
| 19 |
+
### Knowledge Distillation
|
| 20 |
+
Following [Gu et al. 2019](https://arxiv.org/abs/1905.11006), [knowledge distillation](https://arxiv.org/abs/1606.07947) from an autoregressive model can effectively simplify the training data distribution, which is sometimes essential for NAT-based models to learn good translations.
|
| 21 |
+
The easiest way of performing distillation is to follow the [instructions of training a standard transformer model](../translation) on the same data, and then decode the training set to produce a distillation dataset for NAT.
|
| 22 |
+
|
| 23 |
+
### Download
|
| 24 |
+
We also provided the preprocessed [original](http://dl.fbaipublicfiles.com/nat/original_dataset.zip) and [distillation](http://dl.fbaipublicfiles.com/nat/distill_dataset.zip) datasets. Please build the binarized dataset on your own.
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## Train a model
|
| 28 |
+
|
| 29 |
+
Then we can train a nonautoregressive model using the `translation_lev` task and a new criterion `nat_loss`.
|
| 30 |
+
Use the `--noise` flag to specify the input noise used on the target sentences.
|
| 31 |
+
By default, we run the task for *Levenshtein Transformer*, with `--noise='random_delete'`. Full scripts to run other models can also be found [here](./scripts.md).
|
| 32 |
+
|
| 33 |
+
The following command will train a *Levenshtein Transformer* on the binarized dataset.
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
fairseq-train \
|
| 37 |
+
data-bin/wmt14_en_de_distill \
|
| 38 |
+
--save-dir checkpoints \
|
| 39 |
+
--ddp-backend=no_c10d \
|
| 40 |
+
--task translation_lev \
|
| 41 |
+
--criterion nat_loss \
|
| 42 |
+
--arch levenshtein_transformer \
|
| 43 |
+
--noise random_delete \
|
| 44 |
+
--share-all-embeddings \
|
| 45 |
+
--optimizer adam --adam-betas '(0.9,0.98)' \
|
| 46 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt \
|
| 47 |
+
--min-lr '1e-09' --warmup-updates 10000 \
|
| 48 |
+
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
|
| 49 |
+
--dropout 0.3 --weight-decay 0.01 \
|
| 50 |
+
--decoder-learned-pos \
|
| 51 |
+
--encoder-learned-pos \
|
| 52 |
+
--apply-bert-init \
|
| 53 |
+
--log-format 'simple' --log-interval 100 \
|
| 54 |
+
--fixed-validation-seed 7 \
|
| 55 |
+
--max-tokens 8000 \
|
| 56 |
+
--save-interval-updates 10000 \
|
| 57 |
+
--max-update 300000
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
## Translate
|
| 61 |
+
|
| 62 |
+
Once a model is trained, we can generate translations using an `iterative_refinement_generator`, which will start from the model's initial output and iteratively read and greedily refine the translation until (1) the model predicts the same translations for two consecutive iterations; or (2) the generator reaches the maximum iterations (`--iter-decode-max-iter`). Use `--print-step` to check the actual # of iteration for each sentence.
|
| 63 |
+
|
| 64 |
+
For *Levenshtein Transformer*, it sometimes helps to apply a `--iter-decode-eos-penalty` (typically, 0~3) to penalize the model finishing generation too early and generating too short translations.
|
| 65 |
+
|
| 66 |
+
For example, to generate with `--iter-decode-max-iter=9`:
|
| 67 |
+
```bash
|
| 68 |
+
fairseq-generate \
|
| 69 |
+
data-bin/wmt14_en_de_distill \
|
| 70 |
+
--gen-subset test \
|
| 71 |
+
--task translation_lev \
|
| 72 |
+
--path checkpoints/checkpoint_best.pt \
|
| 73 |
+
--iter-decode-max-iter 9 \
|
| 74 |
+
--iter-decode-eos-penalty 0 \
|
| 75 |
+
--beam 1 --remove-bpe \
|
| 76 |
+
--print-step \
|
| 77 |
+
--batch-size 400
|
| 78 |
+
```
|
| 79 |
+
In the end of the generation, we can see the tokenized BLEU score for the translation.
|
| 80 |
+
|
| 81 |
+
## Advanced Decoding Methods
|
| 82 |
+
### Ensemble
|
| 83 |
+
The NAT models use special implementations of [ensembling](https://github.com/fairinternal/fairseq-py/blob/b98d88da52f2f21f1b169bab8c70c1c4ca19a768/fairseq/sequence_generator.py#L522) to support iterative refinement and a variety of parallel operations in different models, while it shares the same API as standard autoregressive models as follows:
|
| 84 |
+
```bash
|
| 85 |
+
fairseq-generate \
|
| 86 |
+
data-bin/wmt14_en_de_distill \
|
| 87 |
+
--gen-subset test \
|
| 88 |
+
--task translation_lev \
|
| 89 |
+
--path checkpoint_1.pt:checkpoint_2.pt:checkpoint_3.pt \
|
| 90 |
+
--iter-decode-max-iter 9 \
|
| 91 |
+
--iter-decode-eos-penalty 0 \
|
| 92 |
+
--beam 1 --remove-bpe \
|
| 93 |
+
--print-step \
|
| 94 |
+
--batch-size 400
|
| 95 |
+
```
|
| 96 |
+
We use ``:`` to split multiple models. Note that, not all NAT models support ensembling for now.
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
### Length-beam
|
| 100 |
+
For models that predict lengths before decoding (e.g. the vanilla NAT, Mask-Predict, etc), it is possible to improve the translation quality by varying the target lengths around the predicted value, and translating the same example multiple times in parallel. We can select the best translation with the highest scores defined by your model's output.
|
| 101 |
+
|
| 102 |
+
Note that, not all models support length beams. For models which dynamically change the lengths (e.g. *Insertion Transformer*, *Levenshtein Transformer*), the same trick does not apply.
|
| 103 |
+
|
| 104 |
+
### Re-ranking
|
| 105 |
+
If the model generates multiple translations with length beam, we can also introduce an autoregressive model to rerank the translations considering scoring from an autoregressive model is much faster than decoding from that.
|
| 106 |
+
|
| 107 |
+
For example, to generate translations with length beam and reranking,
|
| 108 |
+
```bash
|
| 109 |
+
fairseq-generate \
|
| 110 |
+
data-bin/wmt14_en_de_distill \
|
| 111 |
+
--gen-subset test \
|
| 112 |
+
--task translation_lev \
|
| 113 |
+
--path checkpoints/checkpoint_best.pt:at_checkpoints/checkpoint_best.pt \
|
| 114 |
+
--iter-decode-max-iter 9 \
|
| 115 |
+
--iter-decode-eos-penalty 0 \
|
| 116 |
+
--iter-decode-with-beam 9 \
|
| 117 |
+
--iter-decode-with-external-reranker \
|
| 118 |
+
--beam 1 --remove-bpe \
|
| 119 |
+
--print-step \
|
| 120 |
+
--batch-size 100
|
| 121 |
+
```
|
| 122 |
+
Note that we need to make sure the autoregressive model shares the same vocabulary as our target non-autoregressive model.
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
## Citation
|
| 126 |
+
|
| 127 |
+
```bibtex
|
| 128 |
+
@incollection{NIPS2019_9297,
|
| 129 |
+
title = {Levenshtein Transformer},
|
| 130 |
+
author = {Gu, Jiatao and Wang, Changhan and Zhao, Junbo},
|
| 131 |
+
booktitle = {Advances in Neural Information Processing Systems 32},
|
| 132 |
+
editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
|
| 133 |
+
pages = {11179--11189},
|
| 134 |
+
year = {2019},
|
| 135 |
+
publisher = {Curran Associates, Inc.},
|
| 136 |
+
url = {http://papers.nips.cc/paper/9297-levenshtein-transformer.pdf}
|
| 137 |
+
}
|
| 138 |
+
```
|
| 139 |
+
```bibtex
|
| 140 |
+
@article{zhou2019understanding,
|
| 141 |
+
title={Understanding Knowledge Distillation in Non-autoregressive Machine Translation},
|
| 142 |
+
author={Zhou, Chunting and Neubig, Graham and Gu, Jiatao},
|
| 143 |
+
journal={arXiv preprint arXiv:1911.02727},
|
| 144 |
+
year={2019}
|
| 145 |
+
}
|
| 146 |
+
```
|
fairseq-0.10.2/examples/nonautoregressive_translation/scripts.md
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Examples of Training scripts for Non-autoregressive Machine Translation models
|
| 2 |
+
|
| 3 |
+
### Non-autoregressive Transformer (NAT, Gu et al., 2017)
|
| 4 |
+
Note that we need to have an additional module to perform "length prediction" (`--length-loss-factor`) before generating the whole sequence.
|
| 5 |
+
```bash
|
| 6 |
+
fairseq-train \
|
| 7 |
+
data-bin/wmt14_en_de_distill \
|
| 8 |
+
--save-dir checkpoints \
|
| 9 |
+
--ddp-backend=no_c10d \
|
| 10 |
+
--task translation_lev \
|
| 11 |
+
--criterion nat_loss \
|
| 12 |
+
--arch nonautoregressive_transformer \
|
| 13 |
+
--noise full_mask \
|
| 14 |
+
--share-all-embeddings \
|
| 15 |
+
--optimizer adam --adam-betas '(0.9,0.98)' \
|
| 16 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt \
|
| 17 |
+
--min-lr '1e-09' --warmup-updates 10000 \
|
| 18 |
+
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
|
| 19 |
+
--dropout 0.3 --weight-decay 0.01 \
|
| 20 |
+
--decoder-learned-pos \
|
| 21 |
+
--encoder-learned-pos \
|
| 22 |
+
--pred-length-offset \
|
| 23 |
+
--length-loss-factor 0.1 \
|
| 24 |
+
--apply-bert-init \
|
| 25 |
+
--log-format 'simple' --log-interval 100 \
|
| 26 |
+
--fixed-validation-seed 7 \
|
| 27 |
+
--max-tokens 8000 \
|
| 28 |
+
--save-interval-updates 10000 \
|
| 29 |
+
--max-update 300000
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
### Fast Structured Decoding for Sequence Models (NAT-CRF, Sun et al., 2019)
|
| 33 |
+
Note that we implemented a low-rank approximated CRF model by setting `--crf-lowrank-approx=32` and `--crf-beam-approx=64` as described in the original paper. All other settings are the same as the vanilla NAT model.
|
| 34 |
+
```bash
|
| 35 |
+
fairseq-train \
|
| 36 |
+
data-bin/wmt14_en_de_distill \
|
| 37 |
+
--save-dir checkpoints \
|
| 38 |
+
--ddp-backend=no_c10d \
|
| 39 |
+
--task translation_lev \
|
| 40 |
+
--criterion nat_loss \
|
| 41 |
+
--arch nacrf_transformer \
|
| 42 |
+
--noise full_mask \
|
| 43 |
+
--share-all-embeddings \
|
| 44 |
+
--optimizer adam --adam-betas '(0.9,0.98)' \
|
| 45 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt \
|
| 46 |
+
--min-lr '1e-09' --warmup-updates 10000 \
|
| 47 |
+
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
|
| 48 |
+
--dropout 0.3 --weight-decay 0.01 \
|
| 49 |
+
--decoder-learned-pos \
|
| 50 |
+
--encoder-learned-pos \
|
| 51 |
+
--pred-length-offset \
|
| 52 |
+
--length-loss-factor 0.1 \
|
| 53 |
+
--word-ins-loss-factor 0.5 \
|
| 54 |
+
--crf-lowrank-approx 32 \
|
| 55 |
+
--crf-beam-approx 64 \
|
| 56 |
+
--apply-bert-init \
|
| 57 |
+
--log-format 'simple' --log-interval 100 \
|
| 58 |
+
--fixed-validation-seed 7 \
|
| 59 |
+
--max-tokens 8000 \
|
| 60 |
+
--save-interval-updates 10000 \
|
| 61 |
+
--max-update 300000
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
### Non-autoregressive Transformer with Iterative Refinement (iNAT, Lee et al., 2018)
|
| 66 |
+
Note that `--train-step` means how many iterations of refinement we used during training, and `--dae-ratio` controls the ratio of denoising auto-encoder training described in the original paper.
|
| 67 |
+
```bash
|
| 68 |
+
fairseq-train \
|
| 69 |
+
data-bin/wmt14_en_de_distill \
|
| 70 |
+
--save-dir checkpoints \
|
| 71 |
+
--ddp-backend=no_c10d \
|
| 72 |
+
--task translation_lev \
|
| 73 |
+
--criterion nat_loss \
|
| 74 |
+
--arch iterative_nonautoregressive_transformer \
|
| 75 |
+
--noise full_mask \
|
| 76 |
+
--share-all-embeddings \
|
| 77 |
+
--optimizer adam --adam-betas '(0.9,0.98)' \
|
| 78 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt \
|
| 79 |
+
--min-lr '1e-09' --warmup-updates 10000 \
|
| 80 |
+
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
|
| 81 |
+
--dropout 0.3 --weight-decay 0.01 \
|
| 82 |
+
--decoder-learned-pos \
|
| 83 |
+
--encoder-learned-pos \
|
| 84 |
+
--pred-length-offset \
|
| 85 |
+
--length-loss-factor 0.1 \
|
| 86 |
+
--train-step 4 \
|
| 87 |
+
--dae-ratio 0.5 \
|
| 88 |
+
--stochastic-approx \
|
| 89 |
+
--apply-bert-init \
|
| 90 |
+
--log-format 'simple' --log-interval 100 \
|
| 91 |
+
--fixed-validation-seed 7 \
|
| 92 |
+
--max-tokens 8000 \
|
| 93 |
+
--save-interval-updates 10000 \
|
| 94 |
+
--max-update 300000
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### Insertion Transformer (InsT, Stern et al., 2019)
|
| 98 |
+
Note that we need to specify the "slot-loss" (uniform or balanced tree) described in the original paper. Here we use `--label-tau` to control the temperature.
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
fairseq-train \
|
| 102 |
+
data-bin/wmt14_en_de_distill \
|
| 103 |
+
--save-dir checkpoints \
|
| 104 |
+
--ddp-backend=no_c10d \
|
| 105 |
+
--task translation_lev \
|
| 106 |
+
--criterion nat_loss \
|
| 107 |
+
--arch insertion_transformer \
|
| 108 |
+
--noise random_delete \
|
| 109 |
+
--share-all-embeddings \
|
| 110 |
+
--optimizer adam --adam-betas '(0.9,0.98)' \
|
| 111 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt \
|
| 112 |
+
--min-lr '1e-09' --warmup-updates 10000 \
|
| 113 |
+
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
|
| 114 |
+
--dropout 0.3 --weight-decay 0.01 \
|
| 115 |
+
--decoder-learned-pos \
|
| 116 |
+
--encoder-learned-pos \
|
| 117 |
+
--apply-bert-init \
|
| 118 |
+
--log-format 'simple' --log-interval 100 \
|
| 119 |
+
--fixed-validation-seed 7 \
|
| 120 |
+
--max-tokens 8000 \
|
| 121 |
+
--save-interval-updates 10000 \
|
| 122 |
+
--max-update 300000
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
### Mask Predict (CMLM, Ghazvininejad et al., 2019)
|
| 127 |
+
```bash
|
| 128 |
+
fairseq-train \
|
| 129 |
+
data-bin/wmt14_en_de_distill \
|
| 130 |
+
--save-dir checkpoints \
|
| 131 |
+
--ddp-backend=no_c10d \
|
| 132 |
+
--task translation_lev \
|
| 133 |
+
--criterion nat_loss \
|
| 134 |
+
--arch cmlm_transformer \
|
| 135 |
+
--noise random_mask \
|
| 136 |
+
--share-all-embeddings \
|
| 137 |
+
--optimizer adam --adam-betas '(0.9,0.98)' \
|
| 138 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt \
|
| 139 |
+
--min-lr '1e-09' --warmup-updates 10000 \
|
| 140 |
+
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
|
| 141 |
+
--dropout 0.3 --weight-decay 0.01 \
|
| 142 |
+
--decoder-learned-pos \
|
| 143 |
+
--encoder-learned-pos \
|
| 144 |
+
--apply-bert-init \
|
| 145 |
+
--log-format 'simple' --log-interval 100 \
|
| 146 |
+
--fixed-validation-seed 7 \
|
| 147 |
+
--max-tokens 8000 \
|
| 148 |
+
--save-interval-updates 10000 \
|
| 149 |
+
--max-update 300000
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
### Levenshtein Transformer (LevT, Gu et al., 2019)
|
| 156 |
+
```bash
|
| 157 |
+
fairseq-train \
|
| 158 |
+
data-bin/wmt14_en_de_distill \
|
| 159 |
+
--save-dir checkpoints \
|
| 160 |
+
--ddp-backend=no_c10d \
|
| 161 |
+
--task translation_lev \
|
| 162 |
+
--criterion nat_loss \
|
| 163 |
+
--arch levenshtein_transformer \
|
| 164 |
+
--noise random_delete \
|
| 165 |
+
--share-all-embeddings \
|
| 166 |
+
--optimizer adam --adam-betas '(0.9,0.98)' \
|
| 167 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt \
|
| 168 |
+
--min-lr '1e-09' --warmup-updates 10000 \
|
| 169 |
+
--warmup-init-lr '1e-07' --label-smoothing 0.1 \
|
| 170 |
+
--dropout 0.3 --weight-decay 0.01 \
|
| 171 |
+
--decoder-learned-pos \
|
| 172 |
+
--encoder-learned-pos \
|
| 173 |
+
--apply-bert-init \
|
| 174 |
+
--log-format 'simple' --log-interval 100 \
|
| 175 |
+
--fixed-validation-seed 7 \
|
| 176 |
+
--max-tokens 8000 \
|
| 177 |
+
--save-interval-updates 10000 \
|
| 178 |
+
--max-update 300000
|
| 179 |
+
```
|
fairseq-0.10.2/examples/pay_less_attention_paper/README.md
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Pay Less Attention with Lightweight and Dynamic Convolutions (Wu et al., 2019)
|
| 2 |
+
|
| 3 |
+
This page contains pointers to pre-trained models as well as instructions on how to train new models for [our paper](https://arxiv.org/abs/1901.10430).
|
| 4 |
+
|
| 5 |
+
## Citation:
|
| 6 |
+
```bibtex
|
| 7 |
+
@inproceedings{wu2018pay,
|
| 8 |
+
title = {Pay Less Attention with Lightweight and Dynamic Convolutions},
|
| 9 |
+
author = {Felix Wu and Angela Fan and Alexei Baevski and Yann Dauphin and Michael Auli},
|
| 10 |
+
booktitle = {International Conference on Learning Representations},
|
| 11 |
+
year = {2019},
|
| 12 |
+
url = {https://arxiv.org/abs/1901.10430},
|
| 13 |
+
}
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
## Translation
|
| 17 |
+
|
| 18 |
+
### Pre-trained models
|
| 19 |
+
For some datasets we release models without GLUs which are faster at inference.
|
| 20 |
+
|
| 21 |
+
Model | Description | Dataset | Download
|
| 22 |
+
---|---|---|---
|
| 23 |
+
`lightconv.no_glu.iwslt14.de-en` | LightConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.lightconv.tar.gz) <br> IWSLT14 test: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2)
|
| 24 |
+
`dynamicconv.no_glu.iwslt14.de-en` | DynamicConv (without GLUs) | [IWSLT14 German-English](https://wit3.fbk.eu/archive/2014-01/texts/de/en/de-en.tgz) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/iwslt14.de-en.dynamicconv.tar.gz) <br> IWSLT14 test: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/iwslt14.de-en.test.tar.bz2)
|
| 25 |
+
`lightconv.no_glu.wmt16.en-de` | LightConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
|
| 26 |
+
`dynamicconv.no_glu.wmt16.en-de` | DynamicConv (without GLUs) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
|
| 27 |
+
`lightconv.glu.wmt16.en-de` | LightConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.lightconv-glu.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
|
| 28 |
+
`dynamicconv.glu.wmt16.en-de` | DynamicConv | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt16.en-de.joined-dict.dynamicconv-glu.tar.gz) <br> newstest2014 (shared vocab): <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
|
| 29 |
+
`lightconv.glu.wmt14.en-fr` | LightConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.lightconv-glu.tar.gz) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
|
| 30 |
+
`dynamicconv.glu.wmt14.en-fr` | DynamicConv | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt14.en-fr.joined-dict.dynamicconv-glu.tar.gz) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
|
| 31 |
+
`lightconv.glu.wmt17.zh-en` | LightConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.lightconv-glu.tar.gz) <br> newstest2017: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2)
|
| 32 |
+
`dynamicconv.glu.wmt17.zh-en` | DynamicConv | [WMT17 Chinese-English](http://statmt.org/wmt17/translation-task.html#Download) | model: <br> [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/dynamicconv/wmt17.zh-en.dynamicconv-glu.tar.gz) <br> newstest2017: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt17.zh-en.newstest2017.tar.bz2)
|
| 33 |
+
|
| 34 |
+
### Memory-Efficient CUDA Kernels
|
| 35 |
+
|
| 36 |
+
Since the PyTorch implementations of Light/Dynamic conv are quite memory intensive, we have developed CUDA kernels that implement the light and dynamic convolution operator in a memory-efficient and performant manner. For large sequence lengths, these kernels save about 50% memory compared to the PyTorch equivalent.
|
| 37 |
+
|
| 38 |
+
To install the kernels, use the commands below. Once installed, they will automatically be used in place of the PyTorch implementations whenever a light or dynamic convolution is used.
|
| 39 |
+
|
| 40 |
+
```sh
|
| 41 |
+
# to install lightconv
|
| 42 |
+
cd fairseq/modules/lightconv_layer
|
| 43 |
+
python cuda_function_gen.py
|
| 44 |
+
python setup.py install
|
| 45 |
+
|
| 46 |
+
# to install dynamicconv
|
| 47 |
+
cd fairseq/modules/dynamicconv_layer
|
| 48 |
+
python cuda_function_gen.py
|
| 49 |
+
python setup.py install
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### Example usage (torch.hub)
|
| 53 |
+
|
| 54 |
+
We require a few additional Python dependencies for preprocessing:
|
| 55 |
+
```bash
|
| 56 |
+
pip install sacremoses subword_nmt
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Interactive translation via PyTorch Hub:
|
| 60 |
+
```python
|
| 61 |
+
import torch
|
| 62 |
+
|
| 63 |
+
# List available models
|
| 64 |
+
torch.hub.list('pytorch/fairseq') # [..., 'lightconv.glu.wmt17.zh-en', ... ]
|
| 65 |
+
|
| 66 |
+
# Load a LightConv model trained on WMT'17 Zh-En
|
| 67 |
+
zh2en = torch.hub.load('pytorch/fairseq', 'lightconv.glu.wmt17.zh-en', tokenizer='moses', bpe='subword_nmt')
|
| 68 |
+
|
| 69 |
+
# The underlying model is available under the *models* attribute
|
| 70 |
+
assert isinstance(zh2en.models[0], fairseq.models.lightconv.LightConvModel)
|
| 71 |
+
|
| 72 |
+
# Translate a sentence
|
| 73 |
+
zh2en.translate('你好 世界')
|
| 74 |
+
# 'Hello World'
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
Loading custom models:
|
| 78 |
+
```python
|
| 79 |
+
from fairseq.models.lightconv import LightConvModel
|
| 80 |
+
en2fr = LightConvModel.from_pretrained(
|
| 81 |
+
'/path/to/checkpoints',
|
| 82 |
+
checkpoint_file='checkpoint_best.pt',
|
| 83 |
+
data_name_or_path='data-bin/wmt14_en_fr',
|
| 84 |
+
bpe='subword_nmt',
|
| 85 |
+
bpe_codes='data-bin/wmt14_en_fr/en.code'
|
| 86 |
+
)
|
| 87 |
+
en2fr.translate('Hello world!')
|
| 88 |
+
# 'Bonjour le monde'
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### Preprocessing the training datasets
|
| 92 |
+
|
| 93 |
+
Please follow the instructions in [`examples/translation/README.md`](../translation/README.md) to preprocess the data.
|
| 94 |
+
|
| 95 |
+
### Training and evaluation options:
|
| 96 |
+
To use the model without GLU, please set `--encoder-glu 0 --decoder-glu 0`.
|
| 97 |
+
For LightConv, please use `--encoder-conv-type lightweight --decoder-conv-type lightweight`, otherwise the default is DynamicConv.
|
| 98 |
+
For best BLEU results, lenpen may need to be manually tuned.
|
| 99 |
+
|
| 100 |
+
To use the CUDA kernels, first install the PyTorch modules using the commands
|
| 101 |
+
above. Once the CUDA modules are installed, they will automatically be used
|
| 102 |
+
instead of the PyTorch modules.
|
| 103 |
+
|
| 104 |
+
### IWSLT14 De-En
|
| 105 |
+
Training and evaluating DynamicConv (without GLU) on a GPU:
|
| 106 |
+
```sh
|
| 107 |
+
# Training
|
| 108 |
+
SAVE="save/dynamic_conv_iwslt"
|
| 109 |
+
mkdir -p $SAVE
|
| 110 |
+
CUDA_VISIBLE_DEVICES=0 $(which fairseq-train) data-bin/iwslt14.tokenized.de-en \
|
| 111 |
+
--clip-norm 0 --optimizer adam --lr 0.0005 \
|
| 112 |
+
--source-lang de --target-lang en --max-tokens 4000 --no-progress-bar \
|
| 113 |
+
--log-interval 100 --min-lr '1e-09' --weight-decay 0.0001 \
|
| 114 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
|
| 115 |
+
--lr-scheduler inverse_sqrt \
|
| 116 |
+
--ddp-backend=no_c10d \
|
| 117 |
+
--max-update 50000 --warmup-updates 4000 --warmup-init-lr '1e-07' \
|
| 118 |
+
--adam-betas '(0.9, 0.98)' --keep-last-epochs 10 \
|
| 119 |
+
-a lightconv_iwslt_de_en --save-dir $SAVE \
|
| 120 |
+
--dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \
|
| 121 |
+
--encoder-glu 0 --decoder-glu 0
|
| 122 |
+
python scripts/average_checkpoints.py --inputs $SAVE \
|
| 123 |
+
--num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"
|
| 124 |
+
|
| 125 |
+
# Evaluation
|
| 126 |
+
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/iwslt14.tokenized.de-en --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 128 --beam 4 --remove-bpe --lenpen 1 --gen-subset test --quiet
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
### WMT16 En-De
|
| 130 |
+
Training and evaluating DynamicConv (with GLU) on WMT16 En-De using cosine scheduler on one machine with 8 V100 GPUs:
|
| 131 |
+
```sh
|
| 132 |
+
# Training
|
| 133 |
+
SAVE="save/dynamic_conv_wmt16en2de"
|
| 134 |
+
mkdir -p $SAVE
|
| 135 |
+
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
|
| 136 |
+
data-bin/wmt16_en_de_bpe32k --fp16 --log-interval 100 --no-progress-bar \
|
| 137 |
+
--max-update 30000 --share-all-embeddings --optimizer adam \
|
| 138 |
+
--adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \
|
| 139 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
|
| 140 |
+
--min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
|
| 141 |
+
--ddp-backend=no_c10d --max-tokens 3584 \
|
| 142 |
+
--lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
|
| 143 |
+
--lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \
|
| 144 |
+
--t-mult 1 --lr-period-updates 20000 \
|
| 145 |
+
--arch lightconv_wmt_en_de_big --save-dir $SAVE \
|
| 146 |
+
--dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \
|
| 147 |
+
--encoder-glu 1 --decoder-glu 1
|
| 148 |
+
|
| 149 |
+
# Evaluation
|
| 150 |
+
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/wmt16.en-de.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.5 --gen-subset test > wmt16_gen.txt
|
| 151 |
+
bash scripts/compound_split_bleu.sh wmt16_gen.txt
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
### WMT14 En-Fr
|
| 155 |
+
Training DynamicConv (with GLU) on WMT14 En-Fr using cosine scheduler on one machine with 8 V100 GPUs:
|
| 156 |
+
```sh
|
| 157 |
+
# Training
|
| 158 |
+
SAVE="save/dynamic_conv_wmt14en2fr"
|
| 159 |
+
mkdir -p $SAVE
|
| 160 |
+
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
|
| 161 |
+
data-bin/wmt14_en_fr --fp16 --log-interval 100 --no-progress-bar \
|
| 162 |
+
--max-update 30000 --share-all-embeddings --optimizer adam \
|
| 163 |
+
--adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 \
|
| 164 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
|
| 165 |
+
--min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
|
| 166 |
+
--ddp-backend=no_c10d --max-tokens 3584 \
|
| 167 |
+
--lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
|
| 168 |
+
--lr-shrink 1 --max-lr 0.001 --lr 1e-7 --min-lr 1e-9 --warmup-init-lr 1e-07 \
|
| 169 |
+
--t-mult 1 --lr-period-updates 70000 \
|
| 170 |
+
--arch lightconv_wmt_en_fr_big --save-dir $SAVE \
|
| 171 |
+
--dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 \
|
| 172 |
+
--encoder-glu 1 --decoder-glu 1
|
| 173 |
+
|
| 174 |
+
# Evaluation
|
| 175 |
+
CUDA_VISIBLE_DEVICES=0 fairseq-generate data-bin/wmt14.en-fr.joined-dict.newstest2014 --path "${SAVE}/checkpoint_best.pt" --batch-size 128 --beam 5 --remove-bpe --lenpen 0.9 --gen-subset test
|
| 176 |
+
```
|
fairseq-0.10.2/examples/pointer_generator/README.md
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Transformer with Pointer-Generator Network
|
| 2 |
+
|
| 3 |
+
This page describes the `transformer_pointer_generator` model that incorporates
|
| 4 |
+
a pointing mechanism in the Transformer model that facilitates copying of input
|
| 5 |
+
words to the output. This architecture is described in [Enarvi et al. (2020)](https://www.aclweb.org/anthology/2020.nlpmc-1.4/).
|
| 6 |
+
|
| 7 |
+
## Background
|
| 8 |
+
|
| 9 |
+
The pointer-generator network was introduced in [See et al. (2017)](https://arxiv.org/abs/1704.04368)
|
| 10 |
+
for RNN encoder-decoder attention models. A similar mechanism can be
|
| 11 |
+
incorporated in a Transformer model by reusing one of the many attention
|
| 12 |
+
distributions for pointing. The attention distribution over the input words is
|
| 13 |
+
interpolated with the normal output distribution over the vocabulary words. This
|
| 14 |
+
allows the model to generate words that appear in the input, even if they don't
|
| 15 |
+
appear in the vocabulary, helping especially with small vocabularies.
|
| 16 |
+
|
| 17 |
+
## Implementation
|
| 18 |
+
|
| 19 |
+
The mechanism for copying out-of-vocabulary words from the input has been
|
| 20 |
+
implemented differently to See et al. In their [implementation](https://github.com/abisee/pointer-generator)
|
| 21 |
+
they convey the word identities through the model in order to be able to produce
|
| 22 |
+
words that appear in the input sequence but not in the vocabulary. A different
|
| 23 |
+
approach was taken in the Fairseq implementation to keep it self-contained in
|
| 24 |
+
the model file, avoiding any changes to the rest of the code base. Copying
|
| 25 |
+
out-of-vocabulary words is possible by pre-processing the input and
|
| 26 |
+
post-processing the output. This is described in detail in the next section.
|
| 27 |
+
|
| 28 |
+
## Usage
|
| 29 |
+
|
| 30 |
+
The training and evaluation procedure is outlined below. You can also find a
|
| 31 |
+
more detailed example for the XSum dataset on [this page](README.xsum.md).
|
| 32 |
+
|
| 33 |
+
##### 1. Create a vocabulary and extend it with source position markers
|
| 34 |
+
|
| 35 |
+
The pointing mechanism is especially helpful with small vocabularies, if we are
|
| 36 |
+
able to recover the identities of any out-of-vocabulary words that are copied
|
| 37 |
+
from the input. For this purpose, the model allows extending the vocabulary with
|
| 38 |
+
special tokens that can be used in place of `<unk>` tokens to identify different
|
| 39 |
+
input positions. For example, the user may add `<unk-0>`, `<unk-1>`, `<unk-2>`,
|
| 40 |
+
etc. to the end of the vocabulary, after the normal words. Below is an example
|
| 41 |
+
of how to create a vocabulary of 10000 most common words and add 1000 input
|
| 42 |
+
position markers.
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
vocab_size=10000
|
| 46 |
+
position_markers=1000
|
| 47 |
+
export LC_ALL=C
|
| 48 |
+
cat train.src train.tgt |
|
| 49 |
+
tr -s '[:space:]' '\n' |
|
| 50 |
+
sort |
|
| 51 |
+
uniq -c |
|
| 52 |
+
sort -k1,1bnr -k2 |
|
| 53 |
+
head -n "$((vocab_size - 4))" |
|
| 54 |
+
awk '{ print $2 " " $1 }' >dict.pg.txt
|
| 55 |
+
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
##### 2. Preprocess the text data
|
| 59 |
+
|
| 60 |
+
The idea is that any `<unk>` tokens in the text are replaced with `<unk-0>` if
|
| 61 |
+
it appears in the first input position, `<unk-1>` if it appears in the second
|
| 62 |
+
input position, and so on. This can be achieved using the `preprocess.py` script
|
| 63 |
+
that is provided in this directory.
|
| 64 |
+
|
| 65 |
+
##### 3. Train a model
|
| 66 |
+
|
| 67 |
+
The number of these special tokens is given to the model with the
|
| 68 |
+
`--source-position-markers` argument—the model simply maps all of these to the
|
| 69 |
+
same word embedding as `<unk>`.
|
| 70 |
+
|
| 71 |
+
The attention distribution that is used for pointing is selected using the
|
| 72 |
+
`--alignment-heads` and `--alignment-layer` command-line arguments in the same
|
| 73 |
+
way as with the `transformer_align` model.
|
| 74 |
+
|
| 75 |
+
##### 4. Generate text and postprocess it
|
| 76 |
+
|
| 77 |
+
When using the model to generate text, you want to preprocess the input text in
|
| 78 |
+
the same way that training data was processed, replacing out-of-vocabulary words
|
| 79 |
+
with `<unk-N>` tokens. If any of these tokens are copied to the output, the
|
| 80 |
+
actual words can be retrieved from the unprocessed input text. Any `<unk-N>`
|
| 81 |
+
token should be replaced with the word at position N in the original input
|
| 82 |
+
sequence. This can be achieved using the `postprocess.py` script.
|
fairseq-0.10.2/examples/pointer_generator/README.xsum.md
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Training a pointer-generator model on the Extreme Summarization dataset
|
| 2 |
+
|
| 3 |
+
##### 1. Download the Extreme Summarization data and preprocess it
|
| 4 |
+
|
| 5 |
+
Follow the instructions [here](https://github.com/EdinburghNLP/XSum) to obtain
|
| 6 |
+
the original Extreme Summarization dataset. You should have six files,
|
| 7 |
+
{train,validation,test}.{document,summary}.
|
| 8 |
+
|
| 9 |
+
##### 2. Create a vocabulary and extend it with source position markers
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
vocab_size=10000
|
| 13 |
+
position_markers=1000
|
| 14 |
+
export LC_ALL=C
|
| 15 |
+
cat train.document train.summary |
|
| 16 |
+
tr -s '[:space:]' '\n' |
|
| 17 |
+
sort |
|
| 18 |
+
uniq -c |
|
| 19 |
+
sort -k1,1bnr -k2 |
|
| 20 |
+
head -n "$((vocab_size - 4))" |
|
| 21 |
+
awk '{ print $2 " " $1 }' >dict.pg.txt
|
| 22 |
+
python3 -c "[print('<unk-{}> 0'.format(n)) for n in range($position_markers)]" >>dict.pg.txt
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
This creates the file dict.pg.txt that contains the 10k most frequent words,
|
| 26 |
+
followed by 1k source position markers:
|
| 27 |
+
|
| 28 |
+
```
|
| 29 |
+
the 4954867
|
| 30 |
+
. 4157552
|
| 31 |
+
, 3439668
|
| 32 |
+
to 2212159
|
| 33 |
+
a 1916857
|
| 34 |
+
of 1916820
|
| 35 |
+
and 1823350
|
| 36 |
+
...
|
| 37 |
+
<unk-0> 0
|
| 38 |
+
<unk-1> 0
|
| 39 |
+
<unk-2> 0
|
| 40 |
+
<unk-3> 0
|
| 41 |
+
<unk-4> 0
|
| 42 |
+
...
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
##### 3. Preprocess the text data
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
./preprocess.py --source train.document --target train.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out train.pg.src --target-out train.pg.tgt
|
| 49 |
+
./preprocess.py --source validation.document --target validation.summary --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out valid.pg.src --target-out valid.pg.tgt
|
| 50 |
+
./preprocess.py --source test.document --vocab <(cut -d' ' -f1 dict.pg.txt) --source-out test.pg.src
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
The data should now contain `<unk-N>` tokens in place of out-of-vocabulary words.
|
| 54 |
+
|
| 55 |
+
##### 4. Binarize the dataset
|
| 56 |
+
|
| 57 |
+
```bash
|
| 58 |
+
fairseq-preprocess \
|
| 59 |
+
--source-lang src \
|
| 60 |
+
--target-lang tgt \
|
| 61 |
+
--trainpref train.pg \
|
| 62 |
+
--validpref valid.pg \
|
| 63 |
+
--destdir bin \
|
| 64 |
+
--workers 60 \
|
| 65 |
+
--srcdict dict.pg.txt \
|
| 66 |
+
--joined-dictionary
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
##### 5. Train a model
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
total_updates=20000
|
| 73 |
+
warmup_updates=500
|
| 74 |
+
lr=0.001
|
| 75 |
+
max_tokens=4096
|
| 76 |
+
update_freq=4
|
| 77 |
+
pointer_layer=-2
|
| 78 |
+
|
| 79 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 fairseq-train bin \
|
| 80 |
+
--user-dir examples/pointer_generator/pointer_generator_src \
|
| 81 |
+
--max-tokens "$max_tokens" \
|
| 82 |
+
--task translation \
|
| 83 |
+
--source-lang src --target-lang tgt \
|
| 84 |
+
--truncate-source \
|
| 85 |
+
--layernorm-embedding \
|
| 86 |
+
--share-all-embeddings \
|
| 87 |
+
--encoder-normalize-before \
|
| 88 |
+
--decoder-normalize-before \
|
| 89 |
+
--required-batch-size-multiple 1 \
|
| 90 |
+
--arch transformer_pointer_generator \
|
| 91 |
+
--alignment-layer "$pointer_layer" \
|
| 92 |
+
--alignment-heads 1 \
|
| 93 |
+
--source-position-markers 1000 \
|
| 94 |
+
--criterion label_smoothed_cross_entropy \
|
| 95 |
+
--label-smoothing 0.1 \
|
| 96 |
+
--dropout 0.1 --attention-dropout 0.1 \
|
| 97 |
+
--weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
|
| 98 |
+
--clip-norm 0.1 \
|
| 99 |
+
--lr-scheduler inverse_sqrt --lr "$lr" --max-update "$total_updates" --warmup-updates "$warmup_updates" \
|
| 100 |
+
--update-freq "$update_freq" \
|
| 101 |
+
--skip-invalid-size-inputs-valid-test
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
Above we specify that our dictionary contains 1000 source position markers, and
|
| 105 |
+
that we want to use one attention head from the penultimate decoder layer for
|
| 106 |
+
pointing. It should run in 5.5 hours on one node with eight 32GB V100 GPUs. The
|
| 107 |
+
logged messages confirm that dictionary indices above 10000 will be mapped to
|
| 108 |
+
the `<unk>` embedding:
|
| 109 |
+
|
| 110 |
+
```
|
| 111 |
+
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [src] dictionary: 11000 types
|
| 112 |
+
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | [tgt] dictionary: 11000 types
|
| 113 |
+
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.src
|
| 114 |
+
2020-09-24 20:43:53 | INFO | fairseq.data.data_utils | loaded 11332 examples from: bin/valid.src-tgt.tgt
|
| 115 |
+
2020-09-24 20:43:53 | INFO | fairseq.tasks.translation | bin valid src-tgt 11332 examples
|
| 116 |
+
2020-09-24 20:43:53 | INFO | fairseq.models.transformer_pg | dictionary indices from 10000 to 10999 will be mapped to 3
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
##### 6. Summarize the test sequences
|
| 120 |
+
|
| 121 |
+
```bash
|
| 122 |
+
batch_size=32
|
| 123 |
+
beam_size=6
|
| 124 |
+
max_length=60
|
| 125 |
+
length_penalty=1.0
|
| 126 |
+
|
| 127 |
+
fairseq-interactive bin \
|
| 128 |
+
--user-dir examples/pointer_generator/pointer_generator_src \
|
| 129 |
+
--batch-size "$batch_size" \
|
| 130 |
+
--task translation \
|
| 131 |
+
--source-lang src --target-lang tgt \
|
| 132 |
+
--path checkpoints/checkpoint_last.pt \
|
| 133 |
+
--input test.pg.src \
|
| 134 |
+
--buffer-size 200 \
|
| 135 |
+
--max-len-a 0 \
|
| 136 |
+
--max-len-b "$max_length" \
|
| 137 |
+
--lenpen "$length_penalty" \
|
| 138 |
+
--beam "$beam_size" \
|
| 139 |
+
--skip-invalid-size-inputs-valid-test |
|
| 140 |
+
tee generate.out
|
| 141 |
+
grep ^H generate.out | cut -f 3- >generate.hyp
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
Now you should have the generated sequences in `generate.hyp`. They contain
|
| 145 |
+
`<unk-N>` tokens that the model has copied from the source sequence. In order to
|
| 146 |
+
retrieve the original words, we need the unprocessed source sequences from
|
| 147 |
+
`test.document`.
|
| 148 |
+
|
| 149 |
+
##### 7. Process the generated output
|
| 150 |
+
|
| 151 |
+
Since we skipped too long inputs when producing `generate.hyp`, we also have to
|
| 152 |
+
skip too long sequences now that we read `test.document`.
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
./postprocess.py \
|
| 156 |
+
--source <(awk 'NF<1024' test.document) \
|
| 157 |
+
--target generate.hyp \
|
| 158 |
+
--target-out generate.hyp.processed
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
Now you'll find the final sequences from `generate.hyp.processed`, with
|
| 162 |
+
`<unk-N>` replaced with the original word from the source sequence.
|
| 163 |
+
|
| 164 |
+
##### An example of a summarized sequence
|
| 165 |
+
|
| 166 |
+
The original source document in `test.document`:
|
| 167 |
+
|
| 168 |
+
> de roon moved to teesside in june 2016 for an initial # 8.8 m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page .
|
| 169 |
+
|
| 170 |
+
The preprocessed source document in `test.pg.src`:
|
| 171 |
+
|
| 172 |
+
> de \<unk-1> moved to \<unk-4> in june 2016 for an initial # \<unk-12> m fee and played 33 premier league games last term . the netherlands international , 26 , scored five goals in 36 league and cup games during his spell at boro . meanwhile , manager garry monk confirmed the championship club 's interest in signing chelsea midfielder lewis baker . `` he 's a target and one of many that we 've had throughout the summer months , '' said monk . find all the latest football transfers on our dedicated page .
|
| 173 |
+
|
| 174 |
+
The generated summary in `generate.hyp`:
|
| 175 |
+
|
| 176 |
+
> middlesbrough striker \<unk> de \<unk-1> has joined spanish side \<unk> on a season-long loan .
|
| 177 |
+
|
| 178 |
+
The generated summary after postprocessing in `generate.hyp.processed`:
|
| 179 |
+
|
| 180 |
+
> middlesbrough striker \<unk> de roon has joined spanish side \<unk> on a season-long loan .
|
fairseq-0.10.2/examples/pointer_generator/pointer_generator_src/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import transformer_pg # noqa
|
fairseq-0.10.2/examples/pointer_generator/pointer_generator_src/transformer_pg.py
ADDED
|
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Any, Dict, Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
from fairseq import metrics, utils
|
| 12 |
+
from fairseq.models import register_model, register_model_architecture
|
| 13 |
+
from fairseq.models.fairseq_encoder import EncoderOut
|
| 14 |
+
from fairseq.models.transformer import (
|
| 15 |
+
DEFAULT_MAX_SOURCE_POSITIONS,
|
| 16 |
+
DEFAULT_MAX_TARGET_POSITIONS,
|
| 17 |
+
TransformerDecoder,
|
| 18 |
+
TransformerEncoder,
|
| 19 |
+
TransformerModel,
|
| 20 |
+
base_architecture,
|
| 21 |
+
)
|
| 22 |
+
from torch import Tensor
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@register_model("transformer_pointer_generator")
class TransformerPointerGeneratorModel(TransformerModel):
    """
    Transformer model from `"Attention Is All You Need" (Vaswani et al, 2017)
    <https://arxiv.org/abs/1706.03762>`_, augmented with a pointer-generator
    network from `"Get To The Point: Summarization with Pointer-Generator
    Networks" (See et al, 2017) <https://arxiv.org/abs/1704.04368>`_.

    Args:
        encoder (TransformerPointerGeneratorEncoder): the encoder
        decoder (TransformerPointerGeneratorDecoder): the decoder

    The Transformer pointer-generator model provides the following named
    architectures and command-line arguments:

    .. argparse::
        :ref: fairseq.models.transformer_pointer_generator_parser
        :prog:
    """

    @staticmethod
    def add_args(parser):
        """Add model-specific arguments to the parser."""
        # fmt: off
        TransformerModel.add_args(parser)
        parser.add_argument('--alignment-heads', type=int, metavar='N',
                            help='number of attention heads to be used for '
                                 'pointing')
        parser.add_argument('--alignment-layer', type=int, metavar='I',
                            help='layer number to be used for pointing (0 '
                                 'corresponding to the bottommost layer)')
        parser.add_argument('--source-position-markers', type=int, metavar='N',
                            help='dictionary includes N additional items that '
                                 'represent an OOV token at a particular input '
                                 'position')
        parser.add_argument('--force-generation', type=float, metavar='P',
                            default=None,
                            help='set the vocabulary distribution weight to P, '
                                 'instead of predicting it from the input (1.0 '
                                 'corresponding to generation, 0.0 to pointing)')
        # fmt: on

    @classmethod
    def build_model(cls, args, task):
        """Build a new model instance.

        Requires a joined source/target dictionary; embeddings for the
        position-marker entries at the top of the dictionary all alias the
        regular unknown-token embedding (see ``build_embedding`` below).
        """

        # make sure all arguments are present in older models
        base_architecture(args)

        if args.encoder_layers_to_keep:
            args.encoder_layers = len(args.encoder_layers_to_keep.split(","))
        if args.decoder_layers_to_keep:
            args.decoder_layers = len(args.decoder_layers_to_keep.split(","))

        if getattr(args, "max_source_positions", None) is None:
            args.max_source_positions = DEFAULT_MAX_SOURCE_POSITIONS
        if getattr(args, "max_target_positions", None) is None:
            args.max_target_positions = DEFAULT_MAX_TARGET_POSITIONS
        if getattr(args, "source_position_markers", None) is None:
            # Default: assume one OOV marker per possible source position.
            args.source_position_markers = args.max_source_positions

        src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
        if src_dict != tgt_dict:
            raise ValueError("Pointer-generator requires a joined dictionary")

        def build_embedding(dictionary, embed_dim, path=None):
            # The dictionary may include additional items that can be used in
            # place of the normal OOV token and that all map to the same
            # embedding. Using a different token for each input position allows
            # one to restore the word identities from the original source text.
            num_embeddings = len(dictionary) - args.source_position_markers
            padding_idx = dictionary.pad()
            unk_idx = dictionary.unk()
            logger.info(
                "dictionary indices from {0} to {1} will be mapped to {2}".format(
                    num_embeddings, len(dictionary) - 1, unk_idx
                )
            )
            emb = Embedding(num_embeddings, embed_dim, padding_idx, unk_idx)
            # if provided, load from preloaded dictionaries
            if path:
                embed_dict = utils.parse_embedding(path)
                utils.load_embedding(embed_dict, dictionary, emb)
            return emb

        if args.share_all_embeddings:
            if args.encoder_embed_dim != args.decoder_embed_dim:
                raise ValueError(
                    "--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim"
                )
            if args.decoder_embed_path and (
                args.decoder_embed_path != args.encoder_embed_path
            ):
                raise ValueError(
                    "--share-all-embeddings not compatible with --decoder-embed-path"
                )
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            # Reuse the very same embedding module on the decoder side.
            decoder_embed_tokens = encoder_embed_tokens
            args.share_decoder_input_output_embed = True
        else:
            encoder_embed_tokens = build_embedding(
                src_dict, args.encoder_embed_dim, args.encoder_embed_path
            )
            decoder_embed_tokens = build_embedding(
                tgt_dict, args.decoder_embed_dim, args.decoder_embed_path
            )

        encoder = cls.build_encoder(args, src_dict, encoder_embed_tokens)
        decoder = cls.build_decoder(args, tgt_dict, decoder_embed_tokens)
        return cls(args, encoder, decoder)

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        """Build the pointer-generator encoder (also exposes src_tokens)."""
        return TransformerPointerGeneratorEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        """Build the pointer-generator decoder (mixes in attention dists)."""
        return TransformerPointerGeneratorDecoder(args, tgt_dict, embed_tokens)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class TransformerPointerGeneratorEncoder(TransformerEncoder):
    """
    Transformer encoder consisting of *args.encoder_layers* layers. Each layer
    is a :class:`TransformerEncoderLayer`. The pointer-generator variant also
    carries the raw source token ids inside the encoder output, because the
    decoder needs them for pointing but does not otherwise receive them.
    """

    def forward(self, src_tokens, src_lengths, **kwargs):
        """
        Delegate to the parent Transformer encoder, then rebuild the output
        tuple with the source tokens attached.

        Passing the tokens through the encoder output (rather than handing
        them to the decoder's `forward()` directly) avoids changes to
        `SequenceGenerator`.

        Args:
            src_tokens (torch.LongTensor): tokens in the source language of
                shape `(batch, src_len)`
            src_lengths (torch.LongTensor): lengths of each source sentence of
                shape `(batch)`

        Returns:
            namedtuple:
                - **encoder_out** (Tensor): the last encoder layer's output of
                  shape `(src_len, batch, embed_dim)`
                - **encoder_padding_mask** (ByteTensor): the positions of
                  padding elements of shape `(batch, src_len)`
                - **encoder_embedding** (Tensor): the (scaled) embedding lookup
                  of shape `(batch, src_len, embed_dim)`
                - **encoder_states** (List[Tensor]): all intermediate
                  hidden states of shape `(src_len, batch, embed_dim)`.
                  Only populated if *return_all_hiddens* is True.
                - **src_tokens** (Tensor): input token ids of shape
                  `(batch, src_len)`
        """
        base_out = super().forward(src_tokens, src_lengths, **kwargs)
        return EncoderOut(
            encoder_out=base_out.encoder_out,  # T x B x C
            encoder_padding_mask=base_out.encoder_padding_mask,  # B x T
            encoder_embedding=base_out.encoder_embedding,  # B x T x C
            encoder_states=base_out.encoder_states,  # List[T x B x C]
            src_tokens=src_tokens,  # B x T
            src_lengths=None,
        )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
class TransformerPointerGeneratorDecoder(TransformerDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerDecoderLayer`. The pointer-generator variant mixes
    the output probabilities with an attention distribution in the output layer.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
    """

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)

        # In the pointer-generator model these arguments define the decoder
        # layer and the number of attention heads that will be averaged to
        # create the alignment for pointing.
        self.alignment_heads = args.alignment_heads
        self.alignment_layer = args.alignment_layer

        input_embed_dim = embed_tokens.embedding_dim

        # Generation probabilities / interpolation coefficients are predicted
        # from the current decoder input embedding and the decoder output, which
        # is the size of output_embed_dim.
        p_gen_input_size = input_embed_dim + self.output_embed_dim
        self.project_p_gens = nn.Linear(p_gen_input_size, 1)
        nn.init.zeros_(self.project_p_gens.bias)

        # The dictionary may include a separate entry for an OOV token in each
        # input position, so that their identity can be restored from the
        # original source text.
        self.num_types = len(dictionary)
        self.num_oov_types = args.source_position_markers
        # Actual embedding-table size: full dictionary minus position markers.
        self.num_embeddings = self.num_types - self.num_oov_types
        # If not None, overrides the predicted p_gen with a constant.
        self.force_p_gen = args.force_generation

    def forward(
        self,
        prev_output_tokens,
        encoder_out: Optional[EncoderOut] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        features_only: bool = False,
        alignment_layer: Optional[int] = 0,
        alignment_heads: Optional[int] = 1,
        src_lengths: Optional[Any] = None,
        return_all_hiddens: bool = False,
    ):
        """
        Args:
            prev_output_tokens (LongTensor): previous decoder outputs of shape
                `(batch, tgt_len)`, for teacher forcing
            encoder_out (EncoderOut, optional): output from the encoder, used
                for encoder-side attention
            incremental_state (dict, optional): dictionary used for storing
                state during :ref:`Incremental decoding`
            features_only (bool, optional): only return features without
                applying output layer (default: False)
            alignment_layer (int, optional): 0-based index of the layer to be
                used for pointing (default: 0)
            alignment_heads (int, optional): number of attention heads to be
                used for pointing (default: 1)

        Returns:
            tuple:
                - the decoder's output of shape `(batch, tgt_len, vocab)`
                - a dictionary with any model-specific outputs
        """
        # The normal Transformer model doesn't pass the alignment_layer and
        # alignment_heads parameters correctly. We use our local variables.
        x, extra = self.extract_features(
            prev_output_tokens,
            encoder_out=encoder_out,
            incremental_state=incremental_state,
            alignment_layer=self.alignment_layer,
            alignment_heads=self.alignment_heads,
        )
        if not features_only:
            # Embedding the tokens again for generation probability prediction,
            # so that we don't have to reimplement the whole extract_features()
            # method.
            if incremental_state is not None:
                # Incremental decoding: only the newest token is relevant.
                prev_output_tokens = prev_output_tokens[:, -1:]
            prev_output_embed = self.embed_tokens(prev_output_tokens)
            prev_output_embed *= self.embed_scale
            predictors = torch.cat((prev_output_embed, x), 2)
            p_gens = self.project_p_gens(predictors)
            p_gens = torch.sigmoid(p_gens)
            x = self.output_layer(x, extra["attn"][0], encoder_out.src_tokens, p_gens)
        return x, extra

    def output_layer(self, features, attn, src_tokens, p_gens, **kwargs):
        """
        Project features to the vocabulary size and mix with the attention
        distributions.

        Returns a normalized probability distribution over the extended
        vocabulary (regular types + position markers) — NOT raw logits.
        """
        if self.force_p_gen is not None:
            p_gens = self.force_p_gen

        # project back to size of vocabulary
        logits = super().output_layer(features, **kwargs)

        batch_size = logits.shape[0]
        output_length = logits.shape[1]
        assert logits.shape[2] == self.num_embeddings
        assert src_tokens.shape[0] == batch_size
        src_length = src_tokens.shape[1]

        # The final output distribution will be a mixture of the normal output
        # distribution (softmax of logits) and attention weights.
        gen_dists = super().get_normalized_probs(
            (logits, None), log_probs=False, sample=None
        )
        gen_dists = torch.mul(gen_dists, p_gens)
        # Pad the generation distribution with zeros for the OOV marker slots,
        # which can only receive mass through pointing.
        padding_size = (batch_size, output_length, self.num_oov_types)
        padding = gen_dists.new_zeros(padding_size)
        gen_dists = torch.cat((gen_dists, padding), 2)
        assert gen_dists.shape[2] == self.num_types

        # Scatter attention distributions to distributions over the extended
        # vocabulary in a tensor of shape [batch_size, output_length,
        # vocab_size]. Each attention weight will be written into a location
        # that is for other dimensions the same as in the index tensor, but for
        # the third dimension it's the value of the index tensor (the token ID).
        attn = torch.mul(attn, 1 - p_gens)
        index = src_tokens[:, None, :]
        index = index.expand(batch_size, output_length, src_length)
        attn_dists_size = (batch_size, output_length, self.num_types)
        attn_dists = attn.new_zeros(attn_dists_size)
        attn_dists.scatter_add_(2, index, attn)

        # Final distributions, [batch_size, output_length, num_types].
        return gen_dists + attn_dists

    def get_normalized_probs(self, net_output, log_probs, sample):
        """
        Get normalized probabilities (or log probs) from a net's output.
        Pointer-generator network output is already normalized.
        """
        probs = net_output[0]
        # Make sure the probabilities are greater than zero when returning log
        # probabilities.
        return probs.clamp(1e-10, 1.0).log() if log_probs else probs
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class Embedding(nn.Embedding):
    r"""Embedding lookup table that tolerates an extended vocabulary.

    Behaves like :class:`torch.nn.Embedding`, except that any input index
    greater than or equal to ``num_embeddings`` is redirected to ``unk_idx``
    before the lookup. This lets the dictionary carry extra entries (the
    source position markers) that all share the unknown-token embedding, so
    the word identities can later be restored from the original source text.

    Args:
        num_embeddings (int): size of the dictionary of embeddings
        embedding_dim (int): the size of each embedding vector
        padding_idx (int): pads the output with the embedding vector at
            :attr:`padding_idx` (initialized to zeros) whenever it encounters
            the index
        unk_idx (int): all token indices ``>= num_embeddings`` are mapped to
            this index before lookup

    Attributes:
        weight (Tensor): learnable weights of shape
            ``(num_embeddings, embedding_dim)``, initialized from
            :math:`\mathcal{N}(0, \text{embedding\_dim}^{-0.5})`; the
            ``padding_idx`` row is zeroed.

    Shape:
        - Input: :math:`(*)`, LongTensor of arbitrary shape with the indices
        - Output: :math:`(*, H)` where :math:`H=\text{embedding\_dim}`
    """
    __constants__ = ["unk_idx"]

    def __init__(self, num_embeddings, embedding_dim, padding_idx, unk_idx):
        super().__init__(num_embeddings, embedding_dim, padding_idx=padding_idx)
        self.unk_idx = unk_idx
        # Normal(0, dim^-0.5) init with a zeroed padding row.
        nn.init.normal_(self.weight, mean=0, std=embedding_dim ** -0.5)
        nn.init.constant_(self.weight[padding_idx], 0)

    def forward(self, input):
        # Redirect extended-vocabulary entries (position markers) to the
        # shared unknown-token embedding before the standard lookup.
        clipped = input.masked_fill(input >= self.num_embeddings, self.unk_idx)
        return super().forward(clipped)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator"
)
def transformer_pointer_generator(args):
    """Base pointer-generator architecture: base Transformer defaults plus a
    single pointing head taken from the last decoder layer by default."""
    args.alignment_heads = getattr(args, "alignment_heads", 1)
    args.alignment_layer = getattr(args, "alignment_layer", -1)
    base_architecture(args)
    # Resolve a negative layer index (counted from the end) to an absolute one.
    if args.alignment_layer < 0:
        args.alignment_layer += args.decoder_layers
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator_iwslt_de_en"
)
def transformer_pointer_generator_iwslt_de_en(args):
    """IWSLT De-En sized variant: smaller FFN and fewer attention heads."""
    defaults = {
        "encoder_embed_dim": 512,
        "encoder_ffn_embed_dim": 1024,
        "encoder_attention_heads": 4,
        "encoder_layers": 6,
        "decoder_embed_dim": 512,
        "decoder_ffn_embed_dim": 1024,
        "decoder_attention_heads": 4,
        "decoder_layers": 6,
    }
    # Apply each default only where the argument was not set explicitly.
    for name, value in defaults.items():
        setattr(args, name, getattr(args, name, value))
    transformer_pointer_generator(args)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de"
)
def transformer_pointer_generator_wmt_en_de(args):
    """WMT En-De variant — identical to the base pointer-generator setup."""
    transformer_pointer_generator(args)
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
# Transformer pointer-generator with the base Transformer parameters as used in
# the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
    "transformer_pointer_generator",
    "transformer_pointer_generator_vaswani_wmt_en_de_big",
)
def transformer_pointer_generator_vaswani_wmt_en_de_big(args):
    """Big (1024-dim, 16-head, 4096-FFN) pointer-generator, post-norm, dropout 0.3."""
    for key, default in (
        ("encoder_embed_dim", 1024),
        ("encoder_ffn_embed_dim", 4096),
        ("encoder_attention_heads", 16),
        ("encoder_normalize_before", False),
        ("decoder_embed_dim", 1024),
        ("decoder_ffn_embed_dim", 4096),
        ("decoder_attention_heads", 16),
        ("dropout", 0.3),
    ):
        setattr(args, key, getattr(args, key, default))
    transformer_pointer_generator(args)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
@register_model_architecture(
    "transformer_pointer_generator",
    "transformer_pointer_generator_vaswani_wmt_en_fr_big",
)
def transformer_pointer_generator_vaswani_wmt_en_fr_big(args):
    """Big En-Fr pointer-generator: En-De big settings with lighter dropout (0.1)."""
    if not hasattr(args, "dropout"):
        args.dropout = 0.1
    transformer_pointer_generator_vaswani_wmt_en_de_big(args)
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de_big"
)
def transformer_pointer_generator_wmt_en_de_big(args):
    """Big En-De pointer-generator with attention dropout enabled (0.1)."""
    if not hasattr(args, "attention_dropout"):
        args.attention_dropout = 0.1
    transformer_pointer_generator_vaswani_wmt_en_de_big(args)
|
| 457 |
+
|
| 458 |
+
|
| 459 |
+
# default parameters used in tensor2tensor implementation
@register_model_architecture(
    "transformer_pointer_generator", "transformer_pointer_generator_wmt_en_de_big_t2t"
)
def transformer_pointer_generator_wmt_en_de_big_t2t(args):
    """Big pointer-generator with tensor2tensor defaults (pre-norm, extra dropout)."""
    for key, default in (
        ("encoder_normalize_before", True),
        ("decoder_normalize_before", True),
        ("attention_dropout", 0.1),
        ("activation_dropout", 0.1),
    ):
        setattr(args, key, getattr(args, key, default))
    transformer_pointer_generator_vaswani_wmt_en_de_big(args)
|
fairseq-0.10.2/examples/pointer_generator/postprocess.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
import re
|
| 9 |
+
import sys
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class OOVIndexError(IndexError):
    """Raised when a <unk-N> tag points past the end of the source sequence.

    The offending position and both raw sequences are kept on the instance so
    the caller can print a useful diagnostic.
    """

    _MESSAGE = (
        "A <unk-N> tag in the target sequence refers to a position that is "
        "outside the source sequence. Most likely there was a mismatch in "
        "provided source and target sequences. Otherwise this would mean that "
        "the pointing mechanism somehow attended to a position that is past "
        "the actual sequence end."
    )

    def __init__(self, pos, source_seq, target_seq):
        super().__init__(self._MESSAGE)
        self.source_pos = pos
        self.source_seq = source_seq
        self.target_seq = target_seq
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def replace_oovs(source_in, target_in, target_out):
    """Replace ``<unk-N>`` tags in the target text with the source word at
    position ``N``.

    ``source_in`` and ``target_in`` are iterated in lockstep (one line per
    sequence); resolved target lines are written to ``target_out``.

    Raises:
        OOVIndexError: if a tag points past the end of its source sequence.
    """

    oov_re = re.compile("^<unk-([0-9]+)>$")

    for source_seq, target_seq in zip(source_in, target_in):
        source_words = source_seq.strip().split()

        def resolve(token):
            # Non-tag tokens pass through untouched.
            match = oov_re.match(token)
            if match is None:
                return token
            pos = int(match.group(1))
            if pos >= len(source_words):
                raise OOVIndexError(pos, source_seq, target_seq)
            return source_words[pos]

        resolved = [resolve(token) for token in target_seq.strip().split()]
        target_out.write(" ".join(resolved) + "\n")
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def main():
    """Command-line entry point.

    Reads --source and --target line-by-line and writes --target-out with
    every <unk-N> tag replaced by the source word at position N.
    """
    parser = argparse.ArgumentParser(
        description="Replaces <unk-N> tokens in target sequences with words from "
        "the corresponding position in the source sequence."
    )
    parser.add_argument(
        "--source", type=str, help="text file with source sequences", required=True
    )
    parser.add_argument(
        "--target", type=str, help="text file with target sequences", required=True
    )
    parser.add_argument(
        "--target-out",
        type=str,
        help="where to write target sequences without <unk-N> entries",
        required=True,
    )
    args = parser.parse_args()

    # Bug fix: the original pre-opened --target and --target-out into unused
    # variables before this `with` block, leaking two file handles and
    # needlessly truncating the output file twice. All three arguments are
    # required, so the files are opened exactly once, here.
    with open(args.source, "r", encoding="utf-8") as source_in, open(
        args.target, "r", encoding="utf-8"
    ) as target_in, open(args.target_out, "w", encoding="utf-8") as target_out:
        replace_oovs(source_in, target_in, target_out)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
if __name__ == "__main__":
    try:
        main()
    except OOVIndexError as e:
        # The whole diagnostic goes to stderr so stdout stays clean for piping.
        print(e, file=sys.stderr)
        print("Source sequence:", e.source_seq.strip(), file=sys.stderr)
        print("Target sequence:", e.target_seq.strip(), file=sys.stderr)
        print(
            "Source sequence length:",
            len(e.source_seq.strip().split()),
            file=sys.stderr,
        )
        # Bug fix: this line previously printed to stdout, unlike the rest of
        # the error report.
        print("The offending tag points to:", e.source_pos, file=sys.stderr)
        sys.exit(2)
|
fairseq-0.10.2/examples/pointer_generator/preprocess.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import argparse
|
| 8 |
+
from itertools import zip_longest
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def replace_oovs(source_in, target_in, vocabulary, source_out, target_out):
    """Replace out-of-vocabulary words with positional ``<unk-N>`` tags.

    Each OOV word in a source sequence becomes ``<unk-N>`` where N is the
    position of that word's *first* occurrence in the sequence; the same
    word-to-position mapping is then applied to the corresponding target
    sequence (when targets are provided).
    """

    def format_unk(pos):
        return "<unk-{}>".format(pos)

    target_iter = [] if target_in is None else target_in

    for source_seq, target_seq in zip_longest(source_in, target_iter):
        # First OOV position of each unknown word in this source sequence.
        first_oov_pos = {}

        source_tokens_out = []
        for position, token in enumerate(source_seq.strip().split()):
            if token in vocabulary:
                source_tokens_out.append(token)
            else:
                oov_pos = first_oov_pos.setdefault(token, position)
                source_tokens_out.append(format_unk(oov_pos))
        source_out.write(" ".join(source_tokens_out) + "\n")

        if target_seq is not None:
            target_tokens_out = [
                format_unk(first_oov_pos[token]) if token in first_oov_pos else token
                for token in target_seq.strip().split()
            ]
            if target_out is not None:
                target_out.write(" ".join(target_tokens_out) + "\n")
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def main():
    """Command-line entry point.

    Reads --source (and optionally --target), replaces every word not present
    in --vocab with a positional <unk-N> tag, and writes the results to
    --source-out (and --target-out).
    """
    parser = argparse.ArgumentParser(
        description="Replaces out-of-vocabulary words in both source and target "
        "sequences with tokens that indicate the position of the word "
        "in the source sequence."
    )
    parser.add_argument(
        "--source", type=str, help="text file with source sequences", required=True
    )
    parser.add_argument(
        "--target", type=str, help="text file with target sequences", default=None
    )
    parser.add_argument("--vocab", type=str, help="vocabulary file", required=True)
    parser.add_argument(
        "--source-out",
        type=str,
        help="where to write source sequences with <unk-N> entries",
        required=True,
    )
    parser.add_argument(
        "--target-out",
        type=str,
        help="where to write target sequences with <unk-N> entries",
        default=None,
    )
    args = parser.parse_args()

    with open(args.vocab, encoding="utf-8") as vocab:
        # Perf fix: a set gives O(1) membership tests; the original kept a
        # list, making every per-token lookup O(|vocabulary|).
        vocabulary = set(vocab.read().splitlines())

    target_in = (
        open(args.target, "r", encoding="utf-8") if args.target is not None else None
    )
    target_out = (
        open(args.target_out, "w", encoding="utf-8")
        if args.target_out is not None
        else None
    )
    try:
        with open(args.source, "r", encoding="utf-8") as source_in, open(
            args.source_out, "w", encoding="utf-8"
        ) as source_out:
            replace_oovs(source_in, target_in, vocabulary, source_out, target_out)
    finally:
        # Robustness fix: close the optional handles even when replace_oovs
        # raises; the original closed them only on the success path.
        if target_in is not None:
            target_in.close()
        if target_out is not None:
            target_out.close()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Allow running as a script: `python preprocess.py --source ... --vocab ...`.
if __name__ == "__main__":
    main()
|
fairseq-0.10.2/examples/roberta/README.custom_classification.md
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Finetuning RoBERTa on a custom classification task
|
| 2 |
+
|
| 3 |
+
This example shows how to finetune RoBERTa on the IMDB dataset, but should illustrate the process for most classification tasks.
|
| 4 |
+
|
| 5 |
+
### 1) Get the data
|
| 6 |
+
|
| 7 |
+
```bash
|
| 8 |
+
wget http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz
|
| 9 |
+
tar zxvf aclImdb_v1.tar.gz
|
| 10 |
+
```
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
### 2) Format data
|
| 14 |
+
|
| 15 |
+
`IMDB` data has one data-sample per file; the Python snippet below merges them into a single file for each of the train and valid splits, for ease of processing.
|
| 16 |
+
```python
|
| 17 |
+
import argparse
|
| 18 |
+
import os
|
| 19 |
+
import random
|
| 20 |
+
from glob import glob
|
| 21 |
+
|
| 22 |
+
random.seed(0)
|
| 23 |
+
|
| 24 |
+
def main(args):
|
| 25 |
+
for split in ['train', 'test']:
|
| 26 |
+
samples = []
|
| 27 |
+
for class_label in ['pos', 'neg']:
|
| 28 |
+
fnames = glob(os.path.join(args.datadir, split, class_label) + '/*.txt')
|
| 29 |
+
for fname in fnames:
|
| 30 |
+
with open(fname) as fin:
|
| 31 |
+
line = fin.readline()
|
| 32 |
+
samples.append((line, 1 if class_label == 'pos' else 0))
|
| 33 |
+
random.shuffle(samples)
|
| 34 |
+
out_fname = 'train' if split == 'train' else 'dev'
|
| 35 |
+
f1 = open(os.path.join(args.datadir, out_fname + '.input0'), 'w')
|
| 36 |
+
f2 = open(os.path.join(args.datadir, out_fname + '.label'), 'w')
|
| 37 |
+
for sample in samples:
|
| 38 |
+
        f1.write(sample[0].rstrip('\n') + '\n')
|
| 39 |
+
f2.write(str(sample[1]) + '\n')
|
| 40 |
+
f1.close()
|
| 41 |
+
f2.close()
|
| 42 |
+
|
| 43 |
+
if __name__ == '__main__':
|
| 44 |
+
parser = argparse.ArgumentParser()
|
| 45 |
+
parser.add_argument('--datadir', default='aclImdb')
|
| 46 |
+
args = parser.parse_args()
|
| 47 |
+
main(args)
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
### 3) BPE encode
|
| 52 |
+
|
| 53 |
+
Run `multiprocessing_bpe_encoder`; you could also do this in the previous step for each sample, but that might be slower.
|
| 54 |
+
```bash
|
| 55 |
+
# Download encoder.json and vocab.bpe
|
| 56 |
+
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
|
| 57 |
+
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
|
| 58 |
+
|
| 59 |
+
for SPLIT in train dev; do
|
| 60 |
+
python -m examples.roberta.multiprocessing_bpe_encoder \
|
| 61 |
+
--encoder-json encoder.json \
|
| 62 |
+
--vocab-bpe vocab.bpe \
|
| 63 |
+
--inputs "aclImdb/$SPLIT.input0" \
|
| 64 |
+
--outputs "aclImdb/$SPLIT.input0.bpe" \
|
| 65 |
+
--workers 60 \
|
| 66 |
+
--keep-empty
|
| 67 |
+
done
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
### 4) Preprocess data
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
# Download fairseq dictionary.
|
| 75 |
+
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'
|
| 76 |
+
|
| 77 |
+
fairseq-preprocess \
|
| 78 |
+
--only-source \
|
| 79 |
+
--trainpref "aclImdb/train.input0.bpe" \
|
| 80 |
+
--validpref "aclImdb/dev.input0.bpe" \
|
| 81 |
+
--destdir "IMDB-bin/input0" \
|
| 82 |
+
--workers 60 \
|
| 83 |
+
--srcdict dict.txt
|
| 84 |
+
|
| 85 |
+
fairseq-preprocess \
|
| 86 |
+
--only-source \
|
| 87 |
+
--trainpref "aclImdb/train.label" \
|
| 88 |
+
--validpref "aclImdb/dev.label" \
|
| 89 |
+
--destdir "IMDB-bin/label" \
|
| 90 |
+
--workers 60
|
| 91 |
+
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
### 5) Run training
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
TOTAL_NUM_UPDATES=7812 # 10 epochs through IMDB for bsz 32
|
| 99 |
+
WARMUP_UPDATES=469 # 6 percent of the number of updates
|
| 100 |
+
LR=1e-05 # Peak LR for polynomial LR scheduler.
|
| 101 |
+
HEAD_NAME=imdb_head # Custom name for the classification head.
|
| 102 |
+
NUM_CLASSES=2 # Number of classes for the classification task.
|
| 103 |
+
MAX_SENTENCES=8 # Batch size.
|
| 104 |
+
ROBERTA_PATH=/path/to/roberta.large/model.pt
|
| 105 |
+
|
| 106 |
+
CUDA_VISIBLE_DEVICES=0 fairseq-train IMDB-bin/ \
|
| 107 |
+
--restore-file $ROBERTA_PATH \
|
| 108 |
+
--max-positions 512 \
|
| 109 |
+
--batch-size $MAX_SENTENCES \
|
| 110 |
+
--max-tokens 4400 \
|
| 111 |
+
--task sentence_prediction \
|
| 112 |
+
--reset-optimizer --reset-dataloader --reset-meters \
|
| 113 |
+
--required-batch-size-multiple 1 \
|
| 114 |
+
--init-token 0 --separator-token 2 \
|
| 115 |
+
--arch roberta_large \
|
| 116 |
+
--criterion sentence_prediction \
|
| 117 |
+
--classification-head-name $HEAD_NAME \
|
| 118 |
+
--num-classes $NUM_CLASSES \
|
| 119 |
+
--dropout 0.1 --attention-dropout 0.1 \
|
| 120 |
+
--weight-decay 0.1 --optimizer adam --adam-betas "(0.9, 0.98)" --adam-eps 1e-06 \
|
| 121 |
+
--clip-norm 0.0 \
|
| 122 |
+
--lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
|
| 123 |
+
--fp16 --fp16-init-scale 4 --threshold-loss-scale 1 --fp16-scale-window 128 \
|
| 124 |
+
--max-epoch 10 \
|
| 125 |
+
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
|
| 126 |
+
--shorten-method "truncate" \
|
| 127 |
+
--find-unused-parameters \
|
| 128 |
+
--update-freq 4
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
The above command will finetune RoBERTa-large with an effective batch-size of 32
|
| 132 |
+
sentences (`--batch-size=8 --update-freq=4`). The expected
|
| 133 |
+
`best-validation-accuracy` after 10 epochs is ~96.5%.
|
| 134 |
+
|
| 135 |
+
If you run out of GPU memory, try decreasing `--batch-size` and increase
|
| 136 |
+
`--update-freq` to compensate.
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
### 6) Load model using hub interface
|
| 140 |
+
|
| 141 |
+
Now we can load the trained model checkpoint using the RoBERTa hub interface.
|
| 142 |
+
|
| 143 |
+
Assuming your checkpoints are stored in `checkpoints/`:
|
| 144 |
+
```python
|
| 145 |
+
from fairseq.models.roberta import RobertaModel
|
| 146 |
+
roberta = RobertaModel.from_pretrained(
|
| 147 |
+
'checkpoints',
|
| 148 |
+
checkpoint_file='checkpoint_best.pt',
|
| 149 |
+
data_name_or_path='IMDB-bin'
|
| 150 |
+
)
|
| 151 |
+
roberta.eval() # disable dropout
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
Finally you can make predictions using the `imdb_head` (or whatever you set
|
| 155 |
+
`--classification-head-name` to during training):
|
| 156 |
+
```python
|
| 157 |
+
label_fn = lambda label: roberta.task.label_dictionary.string(
|
| 158 |
+
[label + roberta.task.label_dictionary.nspecial]
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
tokens = roberta.encode('Best movie this year')
|
| 162 |
+
pred = label_fn(roberta.predict('imdb_head', tokens).argmax().item())
|
| 163 |
+
assert pred == '1' # positive
|
| 164 |
+
|
| 165 |
+
tokens = roberta.encode('Worst movie ever')
|
| 166 |
+
pred = label_fn(roberta.predict('imdb_head', tokens).argmax().item())
|
| 167 |
+
assert pred == '0' # negative
|
| 168 |
+
```
|
fairseq-0.10.2/examples/roberta/commonsense_qa/README.md
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Finetuning RoBERTa on Commonsense QA
|
| 2 |
+
|
| 3 |
+
We follow a similar approach to [finetuning RACE](../README.race.md). Specifically
|
| 4 |
+
for each question we construct five inputs, one for each of the five candidate
|
| 5 |
+
answer choices. Each input is constructed by concatenating the question and
|
| 6 |
+
candidate answer. We then encode each input and pass the resulting "[CLS]"
|
| 7 |
+
representations through a fully-connected layer to predict the correct answer.
|
| 8 |
+
We train with a standard cross-entropy loss.
|
| 9 |
+
|
| 10 |
+
We also found it helpful to prepend a prefix of `Q:` to the question and `A:` to
|
| 11 |
+
the answer. The complete input format is:
|
| 12 |
+
```
|
| 13 |
+
<s> Q: Where would I not want a fox? </s> A: hen house </s>
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
Our final submission is based on a hyperparameter search over the learning rate
|
| 17 |
+
(1e-5, 2e-5, 3e-5), batch size (8, 16), number of training steps (2000, 3000,
|
| 18 |
+
4000) and random seed. We selected the model with the best performance on the
|
| 19 |
+
development set after 100 trials.
|
| 20 |
+
|
| 21 |
+
### 1) Download data from the Commonsense QA website (https://www.tau-nlp.org/commonsenseqa)
|
| 22 |
+
```bash
|
| 23 |
+
bash examples/roberta/commonsense_qa/download_cqa_data.sh
|
| 24 |
+
```
|
| 25 |
+
|
| 26 |
+
### 2) Finetune
|
| 27 |
+
|
| 28 |
+
```bash
|
| 29 |
+
MAX_UPDATES=3000 # Number of training steps.
|
| 30 |
+
WARMUP_UPDATES=150 # Linearly increase LR over this many steps.
|
| 31 |
+
LR=1e-05 # Peak LR for polynomial LR scheduler.
|
| 32 |
+
MAX_SENTENCES=16 # Batch size.
|
| 33 |
+
SEED=1 # Random seed.
|
| 34 |
+
ROBERTA_PATH=/path/to/roberta/model.pt
|
| 35 |
+
DATA_DIR=data/CommonsenseQA
|
| 36 |
+
|
| 37 |
+
# we use the --user-dir option to load the task from
|
| 38 |
+
# the examples/roberta/commonsense_qa directory:
|
| 39 |
+
FAIRSEQ_PATH=/path/to/fairseq
|
| 40 |
+
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/commonsense_qa
|
| 41 |
+
|
| 42 |
+
CUDA_VISIBLE_DEVICES=0 fairseq-train --fp16 --ddp-backend=no_c10d \
|
| 43 |
+
$DATA_DIR \
|
| 44 |
+
--user-dir $FAIRSEQ_USER_DIR \
|
| 45 |
+
--restore-file $ROBERTA_PATH \
|
| 46 |
+
--reset-optimizer --reset-dataloader --reset-meters \
|
| 47 |
+
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
|
| 48 |
+
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
|
| 49 |
+
--task commonsense_qa --init-token 0 --bpe gpt2 \
|
| 50 |
+
--arch roberta_large --max-positions 512 \
|
| 51 |
+
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
|
| 52 |
+
--criterion sentence_ranking --num-classes 5 \
|
| 53 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 --clip-norm 0.0 \
|
| 54 |
+
--lr-scheduler polynomial_decay --lr $LR \
|
| 55 |
+
--warmup-updates $WARMUP_UPDATES --total-num-update $MAX_UPDATES \
|
| 56 |
+
--batch-size $MAX_SENTENCES \
|
| 57 |
+
--max-update $MAX_UPDATES \
|
| 58 |
+
--log-format simple --log-interval 25 \
|
| 59 |
+
--seed $SEED
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
The above command assumes training on 1 GPU with 32GB of RAM. For GPUs with
|
| 63 |
+
less memory, decrease `--batch-size` and increase `--update-freq`
|
| 64 |
+
accordingly to compensate.
|
| 65 |
+
|
| 66 |
+
### 3) Evaluate
|
| 67 |
+
```python
|
| 68 |
+
import json
|
| 69 |
+
import torch
|
| 70 |
+
from fairseq.models.roberta import RobertaModel
|
| 71 |
+
from examples.roberta import commonsense_qa # load the Commonsense QA task
|
| 72 |
+
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'data/CommonsenseQA')
|
| 73 |
+
roberta.eval() # disable dropout
|
| 74 |
+
roberta.cuda() # use the GPU (optional)
|
| 75 |
+
nsamples, ncorrect = 0, 0
|
| 76 |
+
with open('data/CommonsenseQA/valid.jsonl') as h:
|
| 77 |
+
for line in h:
|
| 78 |
+
example = json.loads(line)
|
| 79 |
+
scores = []
|
| 80 |
+
for choice in example['question']['choices']:
|
| 81 |
+
input = roberta.encode(
|
| 82 |
+
'Q: ' + example['question']['stem'],
|
| 83 |
+
'A: ' + choice['text'],
|
| 84 |
+
no_separator=True
|
| 85 |
+
)
|
| 86 |
+
score = roberta.predict('sentence_classification_head', input, return_logits=True)
|
| 87 |
+
scores.append(score)
|
| 88 |
+
pred = torch.cat(scores).argmax()
|
| 89 |
+
answer = ord(example['answerKey']) - ord('A')
|
| 90 |
+
nsamples += 1
|
| 91 |
+
if pred == answer:
|
| 92 |
+
ncorrect += 1
|
| 93 |
+
|
| 94 |
+
print('Accuracy: ' + str(ncorrect / float(nsamples)))
|
| 95 |
+
# Accuracy: 0.7846027846027847
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
The above snippet is not batched, which makes it quite slow. See [instructions
|
| 99 |
+
for batched prediction with RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta#batched-prediction).
|
fairseq-0.10.2/examples/roberta/commonsense_qa/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import commonsense_qa_task # noqa
|
fairseq-0.10.2/examples/roberta/commonsense_qa/commonsense_qa_task.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
from fairseq.data import (
|
| 12 |
+
Dictionary,
|
| 13 |
+
IdDataset,
|
| 14 |
+
ListDataset,
|
| 15 |
+
NestedDictionaryDataset,
|
| 16 |
+
NumelDataset,
|
| 17 |
+
NumSamplesDataset,
|
| 18 |
+
RawLabelDataset,
|
| 19 |
+
RightPadDataset,
|
| 20 |
+
SortDataset,
|
| 21 |
+
data_utils,
|
| 22 |
+
encoders,
|
| 23 |
+
)
|
| 24 |
+
from fairseq.tasks import LegacyFairseqTask, register_task
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@register_task("commonsense_qa")
|
| 28 |
+
class CommonsenseQATask(LegacyFairseqTask):
|
| 29 |
+
"""Task to finetune RoBERTa for Commonsense QA."""
|
| 30 |
+
|
| 31 |
+
@staticmethod
|
| 32 |
+
def add_args(parser):
|
| 33 |
+
"""Add task-specific arguments to the parser."""
|
| 34 |
+
parser.add_argument(
|
| 35 |
+
"data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
|
| 36 |
+
)
|
| 37 |
+
parser.add_argument(
|
| 38 |
+
"--init-token",
|
| 39 |
+
type=int,
|
| 40 |
+
default=None,
|
| 41 |
+
help="add token at the beginning of each batch item",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument("--num-classes", type=int, default=5)
|
| 44 |
+
|
| 45 |
+
def __init__(self, args, vocab):
|
| 46 |
+
super().__init__(args)
|
| 47 |
+
self.vocab = vocab
|
| 48 |
+
self.mask = vocab.add_symbol("<mask>")
|
| 49 |
+
|
| 50 |
+
self.bpe = encoders.build_bpe(args)
|
| 51 |
+
|
| 52 |
+
@classmethod
|
| 53 |
+
def load_dictionary(cls, filename):
|
| 54 |
+
"""Load the dictionary from the filename
|
| 55 |
+
|
| 56 |
+
Args:
|
| 57 |
+
filename (str): the filename
|
| 58 |
+
"""
|
| 59 |
+
dictionary = Dictionary.load(filename)
|
| 60 |
+
dictionary.add_symbol("<mask>")
|
| 61 |
+
return dictionary
|
| 62 |
+
|
| 63 |
+
@classmethod
|
| 64 |
+
def setup_task(cls, args, **kwargs):
|
| 65 |
+
assert (
|
| 66 |
+
args.criterion == "sentence_ranking"
|
| 67 |
+
), "Must set --criterion=sentence_ranking"
|
| 68 |
+
|
| 69 |
+
# load data and label dictionaries
|
| 70 |
+
vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
|
| 71 |
+
print("| dictionary: {} types".format(len(vocab)))
|
| 72 |
+
|
| 73 |
+
return cls(args, vocab)
|
| 74 |
+
|
| 75 |
+
def load_dataset(
|
| 76 |
+
self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
|
| 77 |
+
):
|
| 78 |
+
"""Load a given dataset split.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
split (str): name of the split (e.g., train, valid, test)
|
| 82 |
+
"""
|
| 83 |
+
|
| 84 |
+
def binarize(s, append_bos=False):
|
| 85 |
+
if self.bpe is not None:
|
| 86 |
+
s = self.bpe.encode(s)
|
| 87 |
+
tokens = self.vocab.encode_line(
|
| 88 |
+
s,
|
| 89 |
+
append_eos=True,
|
| 90 |
+
add_if_not_exist=False,
|
| 91 |
+
).long()
|
| 92 |
+
if append_bos and self.args.init_token is not None:
|
| 93 |
+
tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
|
| 94 |
+
return tokens
|
| 95 |
+
|
| 96 |
+
if data_path is None:
|
| 97 |
+
data_path = os.path.join(self.args.data, split + ".jsonl")
|
| 98 |
+
if not os.path.exists(data_path):
|
| 99 |
+
raise FileNotFoundError("Cannot find data: {}".format(data_path))
|
| 100 |
+
|
| 101 |
+
src_tokens = [[] for i in range(self.args.num_classes)]
|
| 102 |
+
src_lengths = [[] for i in range(self.args.num_classes)]
|
| 103 |
+
labels = []
|
| 104 |
+
|
| 105 |
+
with open(data_path) as h:
|
| 106 |
+
for line in h:
|
| 107 |
+
example = json.loads(line.strip())
|
| 108 |
+
if "answerKey" in example:
|
| 109 |
+
label = ord(example["answerKey"]) - ord("A")
|
| 110 |
+
labels.append(label)
|
| 111 |
+
question = example["question"]["stem"]
|
| 112 |
+
assert len(example["question"]["choices"]) == self.args.num_classes
|
| 113 |
+
# format: `<s> Q: Where would I not want a fox? </s> A: hen house </s>`
|
| 114 |
+
question = "Q: " + question
|
| 115 |
+
question_toks = binarize(question, append_bos=True)
|
| 116 |
+
for i, choice in enumerate(example["question"]["choices"]):
|
| 117 |
+
src = "A: " + choice["text"]
|
| 118 |
+
src_bin = torch.cat([question_toks, binarize(src)])
|
| 119 |
+
src_tokens[i].append(src_bin)
|
| 120 |
+
src_lengths[i].append(len(src_bin))
|
| 121 |
+
assert all(
|
| 122 |
+
len(src_tokens[0]) == len(src_tokens[i])
|
| 123 |
+
for i in range(self.args.num_classes)
|
| 124 |
+
)
|
| 125 |
+
assert len(src_tokens[0]) == len(src_lengths[0])
|
| 126 |
+
assert len(labels) == 0 or len(labels) == len(src_tokens[0])
|
| 127 |
+
|
| 128 |
+
for i in range(self.args.num_classes):
|
| 129 |
+
src_lengths[i] = np.array(src_lengths[i])
|
| 130 |
+
src_tokens[i] = ListDataset(src_tokens[i], src_lengths[i])
|
| 131 |
+
src_lengths[i] = ListDataset(src_lengths[i])
|
| 132 |
+
|
| 133 |
+
dataset = {
|
| 134 |
+
"id": IdDataset(),
|
| 135 |
+
"nsentences": NumSamplesDataset(),
|
| 136 |
+
"ntokens": NumelDataset(src_tokens[0], reduce=True),
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
for i in range(self.args.num_classes):
|
| 140 |
+
dataset.update(
|
| 141 |
+
{
|
| 142 |
+
"net_input{}".format(i + 1): {
|
| 143 |
+
"src_tokens": RightPadDataset(
|
| 144 |
+
src_tokens[i],
|
| 145 |
+
pad_idx=self.source_dictionary.pad(),
|
| 146 |
+
),
|
| 147 |
+
"src_lengths": src_lengths[i],
|
| 148 |
+
}
|
| 149 |
+
}
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
if len(labels) > 0:
|
| 153 |
+
dataset.update({"target": RawLabelDataset(labels)})
|
| 154 |
+
|
| 155 |
+
dataset = NestedDictionaryDataset(
|
| 156 |
+
dataset,
|
| 157 |
+
sizes=[np.maximum.reduce([src_token.sizes for src_token in src_tokens])],
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
with data_utils.numpy_seed(self.args.seed):
|
| 161 |
+
dataset = SortDataset(
|
| 162 |
+
dataset,
|
| 163 |
+
# shuffle
|
| 164 |
+
sort_order=[np.random.permutation(len(dataset))],
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
print("| Loaded {} with {} samples".format(split, len(dataset)))
|
| 168 |
+
|
| 169 |
+
self.datasets[split] = dataset
|
| 170 |
+
return self.datasets[split]
|
| 171 |
+
|
| 172 |
+
def build_model(self, args):
|
| 173 |
+
from fairseq import models
|
| 174 |
+
|
| 175 |
+
model = models.build_model(args, self)
|
| 176 |
+
|
| 177 |
+
model.register_classification_head(
|
| 178 |
+
"sentence_classification_head",
|
| 179 |
+
num_classes=1,
|
| 180 |
+
)
|
| 181 |
+
|
| 182 |
+
return model
|
| 183 |
+
|
| 184 |
+
    @property
    def source_dictionary(self):
        """Dictionary used to encode model inputs (shared with targets)."""
        return self.vocab
|
| 187 |
+
|
| 188 |
+
    @property
    def target_dictionary(self):
        """Dictionary used for targets; same vocabulary as the inputs."""
        return self.vocab
|
fairseq-0.10.2/examples/roberta/commonsense_qa/download_cqa_data.sh
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# Download the CommonsenseQA splits and the GPT-2 BPE dictionary into
# data/CommonsenseQA. NOTE: "commensenseqa" (sic) is the actual S3 bucket
# name used by the dataset authors -- do not "fix" the spelling.

set -e  # abort immediately if any download fails

OUTDIR=data/CommonsenseQA

mkdir -p "$OUTDIR"

wget -O "$OUTDIR/train.jsonl" https://s3.amazonaws.com/commensenseqa/train_rand_split.jsonl
wget -O "$OUTDIR/valid.jsonl" https://s3.amazonaws.com/commensenseqa/dev_rand_split.jsonl
wget -O "$OUTDIR/test.jsonl" https://s3.amazonaws.com/commensenseqa/test_rand_split_no_answers.jsonl
wget -O "$OUTDIR/dict.txt" https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
|
fairseq-0.10.2/examples/roberta/preprocess_RACE.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import re
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class InputExample:
    """One RACE example: a passage, its four question+option strings, and
    the 0-based index of the correct option."""

    def __init__(self, paragraph, qa_list, label):
        self.paragraph = paragraph  # whitespace-normalized article text
        self.qa_list = qa_list  # question/option concatenations, one per option
        self.label = label  # index of the correct option (0-3)


def get_examples(data_dir, set_type):
    """
    Extract the paragraph and question-answer list from each json file.

    Args:
        data_dir: root of the downloaded RACE dataset; expected layout is
            ``<data_dir>/<split>/<level>/*.txt`` with json content.
        set_type: split name, optionally suffixed with a level to restrict
            extraction, e.g. "train", "dev", "test-middle", "test-high".

    Returns:
        list of InputExample for the requested split/level(s).
    """
    examples = []

    # A "-middle"/"-high" suffix restricts extraction to a single level.
    levels = ["middle", "high"]
    set_type_c = set_type.split("-")
    if len(set_type_c) == 2:
        levels = [set_type_c[1]]
        set_type = set_type_c[0]

    for level in levels:
        cur_dir = os.path.join(data_dir, set_type, level)
        for filename in os.listdir(cur_dir):
            cur_path = os.path.join(cur_dir, filename)
            with open(cur_path, "r") as f:
                cur_data = json.load(f)

            # Collapse newlines and runs of whitespace in the passage.
            context = cur_data["article"].replace("\n", " ")
            context = re.sub(r"\s+", " ", context)

            # Walk the parallel answer/question/option lists in lockstep
            # instead of indexing by position.
            for answer, question, option_list in zip(
                cur_data["answers"], cur_data["questions"], cur_data["options"]
            ):
                label = ord(answer) - ord("A")
                qa_list = []
                # RACE always provides exactly four options per question.
                for option in option_list:
                    if "_" in question:
                        # Fill-in-the-blank: embed the option at the blank.
                        qa_cat = question.replace("_", option)
                    else:
                        # Otherwise append the option after the question.
                        qa_cat = " ".join([question, option])
                    qa_list.append(re.sub(r"\s+", " ", qa_cat))
                examples.append(InputExample(context, qa_list, label))

    return examples
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def main():
    """
    Helper script to extract paragraphs, questions and answers from the
    RACE dataset into flat per-split text files.

    For each split it writes:
        <split>.input0          one paragraph per line
        <split>.input1..input4  the four question+option concatenations
        <split>.label           the 0-based index of the correct option
    """
    from contextlib import ExitStack

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        help="input directory for downloaded RACE dataset",
    )
    parser.add_argument(
        "--output-dir",
        help="output directory for extracted data",
    )
    args = parser.parse_args()

    # exist_ok=True already tolerates a pre-existing directory, so no
    # separate existence check is needed.
    os.makedirs(args.output_dir, exist_ok=True)

    for set_type in ["train", "dev", "test-middle", "test-high"]:
        examples = get_examples(args.input_dir, set_type)
        qa_file_paths = [
            os.path.join(args.output_dir, set_type + ".input" + str(i + 1))
            for i in range(4)
        ]
        outf_context_path = os.path.join(args.output_dir, set_type + ".input0")
        outf_label_path = os.path.join(args.output_dir, set_type + ".label")

        # ExitStack guarantees every output file is closed even if a write
        # fails part-way through (the original leaked handles on error).
        with ExitStack() as stack:
            qa_files = [
                stack.enter_context(open(path, "w")) for path in qa_file_paths
            ]
            outf_context = stack.enter_context(open(outf_context_path, "w"))
            outf_label = stack.enter_context(open(outf_label_path, "w"))
            for example in examples:
                outf_context.write(example.paragraph + "\n")
                for i in range(4):
                    qa_files[i].write(example.qa_list[i] + "\n")
                outf_label.write(str(example.label) + "\n")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# Script entry point: run the extraction only when invoked directly.
if __name__ == "__main__":
    main()
|
fairseq-0.10.2/examples/roberta/wsc/README.md
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Finetuning RoBERTa on Winograd Schema Challenge (WSC) data
|
| 2 |
+
|
| 3 |
+
The following instructions can be used to finetune RoBERTa on the WSC training
|
| 4 |
+
data provided by [SuperGLUE](https://super.gluebenchmark.com/).
|
| 5 |
+
|
| 6 |
+
Note that there is high variance in the results. For our GLUE/SuperGLUE
|
| 7 |
+
submission we swept over the learning rate (1e-5, 2e-5, 3e-5), batch size (16,
|
| 8 |
+
32, 64) and total number of updates (500, 1000, 2000, 3000), as well as the
|
| 9 |
+
random seed. Out of ~100 runs we chose the best 7 models and ensembled them.
|
| 10 |
+
|
| 11 |
+
**Approach:** The instructions below use a slightly different loss function than
|
| 12 |
+
what's described in the original RoBERTa arXiv paper. In particular,
|
| 13 |
+
[Kocijan et al. (2019)](https://arxiv.org/abs/1905.06290) introduce a margin
|
| 14 |
+
ranking loss between `(query, candidate)` pairs with tunable hyperparameters
|
| 15 |
+
alpha and beta. This is supported in our code as well with the `--wsc-alpha` and
|
| 16 |
+
`--wsc-beta` arguments. However, we achieved slightly better (and more robust)
|
| 17 |
+
results on the development set by instead using a single cross entropy loss term
|
| 18 |
+
over the log-probabilities for the query and all mined candidates. **The
|
| 19 |
+
candidates are mined using spaCy from each input sentence in isolation, so the
|
| 20 |
+
approach remains strictly pointwise.** This reduces the number of
|
| 21 |
+
hyperparameters and our best model achieved 92.3% development set accuracy,
|
| 22 |
+
compared to ~90% accuracy for the margin loss. Later versions of the RoBERTa
|
| 23 |
+
arXiv paper will describe this updated formulation.
|
| 24 |
+
|
| 25 |
+
### 1) Download the WSC data from the SuperGLUE website:
|
| 26 |
+
```bash
|
| 27 |
+
wget https://dl.fbaipublicfiles.com/glue/superglue/data/v2/WSC.zip
|
| 28 |
+
unzip WSC.zip
|
| 29 |
+
|
| 30 |
+
# we also need to copy the RoBERTa dictionary into the same directory
|
| 31 |
+
wget -O WSC/dict.txt https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### 2) Finetune over the provided training data:
|
| 35 |
+
```bash
|
| 36 |
+
TOTAL_NUM_UPDATES=2000 # Total number of training steps.
|
| 37 |
+
WARMUP_UPDATES=250 # Linearly increase LR over this many steps.
|
| 38 |
+
LR=2e-05 # Peak LR for polynomial LR scheduler.
|
| 39 |
+
MAX_SENTENCES=16 # Batch size per GPU.
|
| 40 |
+
SEED=1 # Random seed.
|
| 41 |
+
ROBERTA_PATH=/path/to/roberta/model.pt
|
| 42 |
+
|
| 43 |
+
# we use the --user-dir option to load the task and criterion
|
| 44 |
+
# from the examples/roberta/wsc directory:
|
| 45 |
+
FAIRSEQ_PATH=/path/to/fairseq
|
| 46 |
+
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
|
| 47 |
+
|
| 48 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train WSC/ \
|
| 49 |
+
--restore-file $ROBERTA_PATH \
|
| 50 |
+
--reset-optimizer --reset-dataloader --reset-meters \
|
| 51 |
+
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
|
| 52 |
+
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
|
| 53 |
+
--valid-subset val \
|
| 54 |
+
--fp16 --ddp-backend no_c10d \
|
| 55 |
+
--user-dir $FAIRSEQ_USER_DIR \
|
| 56 |
+
--task wsc --criterion wsc --wsc-cross-entropy \
|
| 57 |
+
--arch roberta_large --bpe gpt2 --max-positions 512 \
|
| 58 |
+
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
|
| 59 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
|
| 60 |
+
--lr-scheduler polynomial_decay --lr $LR \
|
| 61 |
+
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
|
| 62 |
+
--batch-size $MAX_SENTENCES \
|
| 63 |
+
--max-update $TOTAL_NUM_UPDATES \
|
| 64 |
+
--log-format simple --log-interval 100 \
|
| 65 |
+
--seed $SEED
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
The above command assumes training on 4 GPUs, but you can achieve the same
|
| 69 |
+
results on a single GPU by adding `--update-freq=4`.
|
| 70 |
+
|
| 71 |
+
### 3) Evaluate
|
| 72 |
+
```python
|
| 73 |
+
from fairseq.models.roberta import RobertaModel
|
| 74 |
+
from examples.roberta.wsc import wsc_utils # also loads WSC task and criterion
|
| 75 |
+
roberta = RobertaModel.from_pretrained('checkpoints', 'checkpoint_best.pt', 'WSC/')
|
| 76 |
+
roberta.cuda()
|
| 77 |
+
nsamples, ncorrect = 0, 0
|
| 78 |
+
for sentence, label in wsc_utils.jsonl_iterator('WSC/val.jsonl', eval=True):
|
| 79 |
+
pred = roberta.disambiguate_pronoun(sentence)
|
| 80 |
+
nsamples += 1
|
| 81 |
+
if pred == label:
|
| 82 |
+
ncorrect += 1
|
| 83 |
+
print('Accuracy: ' + str(ncorrect / float(nsamples)))
|
| 84 |
+
# Accuracy: 0.9230769230769231
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
## RoBERTa training on WinoGrande dataset
|
| 88 |
+
We have also provided `winogrande` task and criterion for finetuning on the
[WinoGrande](https://mosaic.allenai.org/projects/winogrande)-like datasets
where there are always two candidates and one is correct.
It is a more efficient implementation for this special case.
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
TOTAL_NUM_UPDATES=23750 # Total number of training steps.
|
| 95 |
+
WARMUP_UPDATES=2375 # Linearly increase LR over this many steps.
|
| 96 |
+
LR=1e-05 # Peak LR for polynomial LR scheduler.
|
| 97 |
+
MAX_SENTENCES=32 # Batch size per GPU.
|
| 98 |
+
SEED=1 # Random seed.
|
| 99 |
+
ROBERTA_PATH=/path/to/roberta/model.pt
|
| 100 |
+
|
| 101 |
+
# we use the --user-dir option to load the task and criterion
|
| 102 |
+
# from the examples/roberta/wsc directory:
|
| 103 |
+
FAIRSEQ_PATH=/path/to/fairseq
|
| 104 |
+
FAIRSEQ_USER_DIR=${FAIRSEQ_PATH}/examples/roberta/wsc
|
| 105 |
+
|
| 106 |
+
cd fairseq
|
| 107 |
+
CUDA_VISIBLE_DEVICES=0 fairseq-train winogrande_1.0/ \
|
| 108 |
+
--restore-file $ROBERTA_PATH \
|
| 109 |
+
--reset-optimizer --reset-dataloader --reset-meters \
|
| 110 |
+
--no-epoch-checkpoints --no-last-checkpoints --no-save-optimizer-state \
|
| 111 |
+
--best-checkpoint-metric accuracy --maximize-best-checkpoint-metric \
|
| 112 |
+
--valid-subset val \
|
| 113 |
+
--fp16 --ddp-backend no_c10d \
|
| 114 |
+
--user-dir $FAIRSEQ_USER_DIR \
|
| 115 |
+
--task winogrande --criterion winogrande \
|
| 116 |
+
--wsc-margin-alpha 5.0 --wsc-margin-beta 0.4 \
|
| 117 |
+
--arch roberta_large --bpe gpt2 --max-positions 512 \
|
| 118 |
+
--dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
|
| 119 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-06 \
|
| 120 |
+
--lr-scheduler polynomial_decay --lr $LR \
|
| 121 |
+
--warmup-updates $WARMUP_UPDATES --total-num-update $TOTAL_NUM_UPDATES \
|
| 122 |
+
--batch-size $MAX_SENTENCES \
|
| 123 |
+
--max-update $TOTAL_NUM_UPDATES \
|
| 124 |
+
--log-format simple --log-interval 100
|
| 125 |
+
```
|
fairseq-0.10.2/examples/roberta/wsc/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import wsc_criterion # noqa
|
| 7 |
+
from . import wsc_task # noqa
|
fairseq-0.10.2/examples/roberta/wsc/wsc_criterion.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from fairseq import utils
|
| 11 |
+
from fairseq.criterions import LegacyFairseqCriterion, register_criterion
|
| 12 |
+
from fairseq.data import encoders
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@register_criterion("wsc")
class WSCCriterion(LegacyFairseqCriterion):
    """Criterion for WSC finetuning.

    Scores the query span against mined candidate spans by masking each
    span and measuring the model's log-probability of reconstructing the
    original tokens. Supports either a margin ranking loss (default) or a
    cross-entropy formulation (--wsc-cross-entropy).
    """

    def __init__(self, args, task):
        super().__init__(args, task)
        # Optionally stream per-example predictions to a TSV file.
        if self.args.save_predictions is not None:
            self.prediction_h = open(self.args.save_predictions, "w")
        else:
            self.prediction_h = None
        self.bpe = encoders.build_bpe(args)
        self.tokenizer = encoders.build_tokenizer(args)

    def __del__(self):
        # Best-effort close of the predictions file when the criterion is
        # garbage-collected.
        if self.prediction_h is not None:
            self.prediction_h.close()

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # alpha/beta are the margin-loss hyperparameters from
        # Kocijan et al. (2019); ignored under --wsc-cross-entropy.
        parser.add_argument("--wsc-margin-alpha", type=float, metavar="A", default=1.0)
        parser.add_argument("--wsc-margin-beta", type=float, metavar="B", default=0.0)
        parser.add_argument(
            "--wsc-cross-entropy",
            action="store_true",
            help="use cross entropy formulation instead of margin loss",
        )
        parser.add_argument(
            "--save-predictions", metavar="FILE", help="file to save predictions to"
        )

    def get_masked_input(self, tokens, mask):
        """Return a copy of `tokens` with masked positions set to <mask>."""
        masked_tokens = tokens.clone()
        masked_tokens[mask] = self.task.mask
        return masked_tokens

    def get_lprobs(self, model, tokens, mask):
        """Mean log-probability of the original tokens over masked positions.

        Runs the model on the masked input and averages the log-probability
        assigned to each original token across the span, yielding one score
        per sequence in the batch.
        """
        logits, _ = model(src_tokens=self.get_masked_input(tokens, mask))
        lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
        scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
        mask = mask.type_as(scores)
        # Average only over the masked span, not padding/context.
        scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
        return scores

    def get_loss(self, query_lprobs, cand_lprobs):
        """Loss comparing the query score against candidate scores."""
        if self.args.wsc_cross_entropy:
            # Treat [query, cand_1..cand_k] as a softmax over k+1 classes
            # with the query (index 0) as the correct class.
            return F.cross_entropy(
                torch.cat([query_lprobs, cand_lprobs]).unsqueeze(0),
                query_lprobs.new([0]).long(),
            )
        else:
            # Margin ranking loss: penalize candidates that come within
            # `beta` of (or exceed) the query score, scaled by `alpha`.
            return (
                -query_lprobs
                + self.args.wsc_margin_alpha
                * (cand_lprobs - query_lprobs + self.args.wsc_margin_beta).clamp(min=0)
            ).sum()

    def forward(self, model, sample, reduce=True):
        # compute loss and accuracy
        loss, nloss = 0.0, 0
        ncorrect, nqueries = 0, 0

        # Examples are scored one at a time because each example has a
        # variable number of candidates.
        for i, label in enumerate(sample["labels"]):
            query_lprobs = self.get_lprobs(
                model,
                sample["query_tokens"][i].unsqueeze(0),
                sample["query_masks"][i].unsqueeze(0),
            )
            cand_lprobs = self.get_lprobs(
                model,
                sample["candidate_tokens"][i],
                sample["candidate_masks"][i],
            )

            # Predict True iff the query outscores every candidate.
            pred = (query_lprobs >= cand_lprobs).all().item()

            if label is not None:
                label = 1 if label else 0
                ncorrect += 1 if pred == label else 0
                nqueries += 1

                if label:
                    # only compute a loss for positive instances
                    nloss += 1
                    loss += self.get_loss(query_lprobs, cand_lprobs)

            # NOTE(review): `id` shadows the builtin; left unchanged here.
            id = sample["id"][i].item()
            if self.prediction_h is not None:
                print("{}\t{}\t{}".format(id, pred, label), file=self.prediction_h)

        if nloss == 0:
            # No positive instances in the batch: emit a differentiable
            # zero so the backward pass still works.
            loss = torch.tensor(0.0, requires_grad=True)

        sample_size = nqueries if nqueries > 0 else 1
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "ncorrect": ncorrect,
            "nqueries": nqueries,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)

        agg_output = {
            # Per-query loss, converted from nats to bits via log(2).
            "loss": loss_sum / sample_size / math.log(2),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }

        ncorrect = sum(log.get("ncorrect", 0) for log in logging_outputs)
        nqueries = sum(log.get("nqueries", 0) for log in logging_outputs)
        if nqueries > 0:
            agg_output["accuracy"] = ncorrect / float(nqueries)

        return agg_output
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@register_criterion("winogrande")
class WinograndeCriterion(WSCCriterion):
    """WSC criterion specialized for WinoGrande-style data.

    Every example has exactly one query and one competing candidate, so the
    whole batch can be scored with two batched forward passes instead of
    the per-example loop used by WSCCriterion.
    """

    def forward(self, model, sample, reduce=True):
        # Score queries and their single alternatives in one pass each.
        query_scores = self.get_lprobs(
            model,
            sample["query_tokens"],
            sample["query_masks"],
        )
        alt_scores = self.get_lprobs(
            model,
            sample["candidate_tokens"],
            sample["candidate_masks"],
        )

        # A prediction is correct when the query outscores its alternative.
        correct = query_scores >= alt_scores
        loss = self.get_loss(query_scores, alt_scores)

        batch_size = sample["query_tokens"].size(0)
        stats = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": batch_size,
            "ncorrect": correct.sum().item(),
            "nqueries": batch_size,
        }
        return loss, batch_size, stats
|
fairseq-0.10.2/examples/roberta/wsc/wsc_task.py
ADDED
|
@@ -0,0 +1,401 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
import tempfile
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from fairseq import utils
|
| 14 |
+
from fairseq.data import (
|
| 15 |
+
Dictionary,
|
| 16 |
+
IdDataset,
|
| 17 |
+
ListDataset,
|
| 18 |
+
NestedDictionaryDataset,
|
| 19 |
+
NumelDataset,
|
| 20 |
+
NumSamplesDataset,
|
| 21 |
+
PadDataset,
|
| 22 |
+
SortDataset,
|
| 23 |
+
data_utils,
|
| 24 |
+
encoders,
|
| 25 |
+
)
|
| 26 |
+
from fairseq.tasks import LegacyFairseqTask, register_task
|
| 27 |
+
|
| 28 |
+
from . import wsc_utils
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@register_task("wsc")
|
| 32 |
+
class WSCTask(LegacyFairseqTask):
|
| 33 |
+
"""Task to finetune RoBERTa for Winograd Schemas."""
|
| 34 |
+
|
| 35 |
+
    @staticmethod
    def add_args(parser):
        """Add task-specific arguments to the parser."""
        parser.add_argument(
            "data", metavar="DIR", help="path to data directory; we load <split>.jsonl"
        )
        parser.add_argument(
            "--init-token",
            type=int,
            # When set, this index is prepended to every binarized sequence
            # (see binarize()); presumably the BOS index -- TODO confirm
            # against the finetuning scripts.
            default=None,
            help="add token at the beginning of each batch item",
        )
|
| 47 |
+
|
| 48 |
+
def __init__(self, args, vocab):
|
| 49 |
+
super().__init__(args)
|
| 50 |
+
self.vocab = vocab
|
| 51 |
+
self.mask = vocab.add_symbol("<mask>")
|
| 52 |
+
|
| 53 |
+
self.bpe = encoders.build_bpe(args)
|
| 54 |
+
self.tokenizer = encoders.build_tokenizer(args)
|
| 55 |
+
|
| 56 |
+
# hack to handle GPT-2 BPE, which includes leading spaces
|
| 57 |
+
if args.bpe == "gpt2":
|
| 58 |
+
self.leading_space = True
|
| 59 |
+
self.trailing_space = False
|
| 60 |
+
else:
|
| 61 |
+
self.leading_space = False
|
| 62 |
+
self.trailing_space = True
|
| 63 |
+
|
| 64 |
+
    @classmethod
    def load_dictionary(cls, filename):
        """Load the dictionary from the filename

        Args:
            filename (str): the filename
        """
        dictionary = Dictionary.load(filename)
        # Register <mask> so masked-span scoring has a dedicated symbol.
        dictionary.add_symbol("<mask>")
        return dictionary
|
| 74 |
+
|
| 75 |
+
    @classmethod
    def setup_task(cls, args, **kwargs):
        """Instantiate the WSC task from command-line args (classmethod factory)."""
        # The task's dataset layout and sample format only work with the
        # matching criterion.
        assert args.criterion == "wsc", "Must set --criterion=wsc"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)
|
| 84 |
+
|
| 85 |
+
    def binarize(self, s: str, append_eos: bool = False):
        """Tokenize, BPE-encode, and map a string to a LongTensor of indices.

        Applies the optional word tokenizer, then the optional BPE, then the
        vocabulary lookup; optionally prepends --init-token and appends EOS.
        """
        if self.tokenizer is not None:
            s = self.tokenizer.encode(s)
        if self.bpe is not None:
            s = self.bpe.encode(s)
        tokens = self.vocab.encode_line(
            s,
            append_eos=append_eos,
            # Unknown tokens must not grow the vocabulary at encode time.
            add_if_not_exist=False,
        ).long()
        if self.args.init_token is not None:
            tokens = torch.cat([tokens.new([self.args.init_token]), tokens])
        return tokens
|
| 98 |
+
|
| 99 |
+
    def binarize_with_mask(self, txt, prefix, suffix, leading_space, trailing_space):
        """Binarize the full sentence and mark which tokens belong to `txt`.

        Returns (tokens, mask) where `mask` is True exactly on the positions
        covering `txt`. The span boundaries are found by binarizing the
        prefix and the (space + txt) fragment separately; this assumes the
        BPE tokenizes those fragments the same way in isolation as inside
        the full sentence -- the leading/trailing-space handling exists to
        keep that assumption valid for GPT-2 BPE.
        """
        toks = self.binarize(
            prefix + leading_space + txt + trailing_space + suffix,
            append_eos=True,
        )
        mask = torch.zeros_like(toks, dtype=torch.bool)
        mask_start = len(self.binarize(prefix))
        mask_size = len(self.binarize(leading_space + txt))
        mask[mask_start : mask_start + mask_size] = 1
        return toks, mask
|
| 109 |
+
|
| 110 |
+
    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
            data_path: explicit .jsonl path, overriding <data>/<split>.jsonl
                (used by build_dataset_for_inference).
            return_only: if True, return the dataset without registering it
                in self.datasets.
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        # Parallel per-example lists; each example contributes one query
        # span and a variable-size matrix of candidate spans.
        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []
        labels = []

        for sentence, pronoun_span, query, label in wsc_utils.jsonl_iterator(data_path):
            # Sentence text around the pronoun; the pronoun itself is
            # replaced by each candidate/query span when binarizing.
            prefix = sentence[: pronoun_span.start].text
            suffix = sentence[pronoun_span.end :].text_with_ws

            # spaCy spans include trailing spaces, but we need to know about
            # leading spaces for the GPT-2 BPE
            leading_space = (
                " " if sentence[: pronoun_span.start].text_with_ws.endswith(" ") else ""
            )
            trailing_space = " " if pronoun_span.text_with_ws.endswith(" ") else ""

            # get noun phrases, excluding pronouns and anything overlapping with the query
            cand_spans = wsc_utils.filter_noun_chunks(
                wsc_utils.extended_noun_chunks(sentence),
                exclude_pronouns=True,
                exclude_query=query,
                exact_match=False,
            )

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query, prefix, suffix, leading_space, trailing_space
                )
                query_len = len(query_toks)
            else:
                # Inference-time examples may have no gold query.
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)

            cand_toks, cand_masks = [], []
            for cand_span in cand_spans:
                toks, mask = self.binarize_with_mask(
                    cand_span.text,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                cand_toks.append(toks)
                cand_masks.append(mask)

            # collate candidates
            cand_toks = data_utils.collate_tokens(cand_toks, pad_idx=self.vocab.pad())
            cand_masks = data_utils.collate_tokens(cand_masks, pad_idx=0)
            assert cand_toks.size() == cand_masks.size()

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_masks)
            candidate_lengths.append(cand_toks.size(1))

            labels.append(label)

        query_lengths = np.array(query_lengths)
        query_tokens = ListDataset(query_tokens, query_lengths)
        query_masks = ListDataset(query_masks, query_lengths)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = ListDataset(candidate_tokens, candidate_lengths)
        candidate_masks = ListDataset(candidate_masks, candidate_lengths)

        labels = ListDataset(labels, [1] * len(labels))

        dataset = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "labels": labels,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        # Fixed-seed shuffle so data order is reproducible across runs.
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))

        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]
|
| 224 |
+
|
| 225 |
+
    def build_dataset_for_inference(self, sample_json):
        """Build a single-example dataset from an in-memory WSC JSON sample.

        The sample is round-tripped through a temporary .jsonl file so that
        the regular ``load_dataset()`` path (and all of its preprocessing)
        can be reused unchanged.
        """
        # buffering=0 forces the bytes to hit disk before load_dataset reads them
        with tempfile.NamedTemporaryFile(buffering=0) as h:
            h.write((json.dumps(sample_json) + "\n").encode("utf-8"))
            dataset = self.load_dataset(
                "disambiguate_pronoun",
                data_path=h.name,
                return_only=True,
            )
        return dataset
|
| 234 |
+
|
| 235 |
+
    def disambiguate_pronoun(self, model, sentence, use_cuda=False):
        """Resolve the marked pronoun in *sentence* with a masked-LM score.

        Each candidate span is replaced by mask tokens and scored by the
        model's average per-token log-probability over that span.  When the
        sentence marks a query span (gold candidate), returns True/False —
        whether the query scored at least as high as every other candidate.
        Otherwise returns the best-scoring candidate as a decoded string.
        """
        sample_json = wsc_utils.convert_sentence_to_json(sentence)
        dataset = self.build_dataset_for_inference(sample_json)
        sample = dataset.collater([dataset[0]])
        if use_cuda:
            sample = utils.move_to_cuda(sample)

        def get_masked_input(tokens, mask):
            # replace the candidate span positions with the <mask> token id
            masked_tokens = tokens.clone()
            masked_tokens[mask.bool()] = self.mask
            return masked_tokens

        def get_lprobs(tokens, mask):
            logits, _ = model(src_tokens=get_masked_input(tokens, mask))
            lprobs = F.log_softmax(logits, dim=-1, dtype=torch.float)
            # log-prob of each original token at its own position
            scores = lprobs.gather(2, tokens.unsqueeze(-1)).squeeze(-1)
            mask = mask.type_as(scores)
            # mean log-prob over the masked (candidate) positions only
            scores = (scores * mask).sum(dim=-1) / mask.sum(dim=-1)
            return scores

        cand_lprobs = get_lprobs(
            sample["candidate_tokens"][0],
            sample["candidate_masks"][0],
        )
        if sample["query_tokens"][0] is not None:
            query_lprobs = get_lprobs(
                sample["query_tokens"][0].unsqueeze(0),
                sample["query_masks"][0].unsqueeze(0),
            )
            # gold query "wins" only if no alternative candidate outscores it
            return (query_lprobs >= cand_lprobs).all().item() == 1
        else:
            best_idx = cand_lprobs.argmax().item()
            full_cand = sample["candidate_tokens"][0][best_idx]
            mask = sample["candidate_masks"][0][best_idx]
            # keep only the candidate-span tokens, then decode back to text
            toks = full_cand[mask.bool()]
            return self.bpe.decode(self.source_dictionary.string(toks)).strip()
|
| 271 |
+
|
| 272 |
+
    @property
    def source_dictionary(self):
        # a single shared vocabulary serves as the source dictionary
        return self.vocab
|
| 275 |
+
|
| 276 |
+
    @property
    def target_dictionary(self):
        # same shared vocabulary on the target side
        return self.vocab
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
@register_task("winogrande")
class WinograndeTask(WSCTask):
    """
    Task for WinoGrande dataset. Efficient implementation for Winograd schema
    tasks with exactly two candidates, one of which is correct.
    """

    @classmethod
    def setup_task(cls, args, **kwargs):
        """Instantiate the task; requires the matching winogrande criterion."""
        assert args.criterion == "winogrande", "Must set --criterion=winogrande"

        # load data and label dictionaries
        vocab = cls.load_dictionary(os.path.join(args.data, "dict.txt"))
        print("| dictionary: {} types".format(len(vocab)))

        return cls(args, vocab)

    def load_dataset(
        self, split, epoch=1, combine=False, data_path=None, return_only=False, **kwargs
    ):
        """Load a given dataset split.

        Args:
            split (str): name of the split (e.g., train, valid, test)
        """
        if data_path is None:
            data_path = os.path.join(self.args.data, split + ".jsonl")
        if not os.path.exists(data_path):
            raise FileNotFoundError("Cannot find data: {}".format(data_path))

        query_tokens = []
        query_masks = []
        query_lengths = []
        candidate_tokens = []
        candidate_masks = []
        candidate_lengths = []

        # the test split carries no gold answer, so iterate in eval mode
        itr = wsc_utils.winogrande_jsonl_iterator(data_path, eval=(split == "test"))

        for sample in itr:
            sentence, pronoun_span, query, cand_text = sample
            # text before and after the "_" placeholder
            prefix = sentence[: pronoun_span[0]].rstrip()
            suffix = sentence[pronoun_span[1] :]

            leading_space = " " if sentence[: pronoun_span[0]].endswith(" ") else ""
            trailing_space = ""

            if query is not None:
                query_toks, query_mask = self.binarize_with_mask(
                    query,
                    prefix,
                    suffix,
                    leading_space,
                    trailing_space,
                )
                query_len = len(query_toks)
            else:
                # eval mode: no gold answer available
                query_toks, query_mask, query_len = None, None, 0

            query_tokens.append(query_toks)
            query_masks.append(query_mask)
            query_lengths.append(query_len)

            cand_toks, cand_mask = self.binarize_with_mask(
                cand_text,
                prefix,
                suffix,
                leading_space,
                trailing_space,
            )

            candidate_tokens.append(cand_toks)
            candidate_masks.append(cand_mask)
            candidate_lengths.append(cand_toks.size(0))

        query_lengths = np.array(query_lengths)

        def get_pad_dataset_fn(tokens, length, pad_idx):
            # wrap raw tensors in a right-padded dataset view
            return PadDataset(
                ListDataset(tokens, length),
                pad_idx=pad_idx,
                left_pad=False,
            )

        query_tokens = get_pad_dataset_fn(query_tokens, query_lengths, self.vocab.pad())
        query_masks = get_pad_dataset_fn(query_masks, query_lengths, 0)

        candidate_lengths = np.array(candidate_lengths)
        candidate_tokens = get_pad_dataset_fn(
            candidate_tokens, candidate_lengths, self.vocab.pad()
        )
        candidate_masks = get_pad_dataset_fn(candidate_masks, candidate_lengths, 0)

        dataset = {
            "id": IdDataset(),
            "query_tokens": query_tokens,
            "query_masks": query_masks,
            "candidate_tokens": candidate_tokens,
            "candidate_masks": candidate_masks,
            "nsentences": NumSamplesDataset(),
            "ntokens": NumelDataset(query_tokens, reduce=True),
        }

        nested_dataset = NestedDictionaryDataset(
            dataset,
            sizes=[query_lengths],
        )

        # seed the permutation so the shuffle is reproducible across runs
        with data_utils.numpy_seed(self.args.seed):
            shuffle = np.random.permutation(len(query_tokens))
        dataset = SortDataset(
            nested_dataset,
            # shuffle
            sort_order=[shuffle],
        )

        if return_only:
            return dataset

        self.datasets[split] = dataset
        return self.datasets[split]
|
fairseq-0.10.2/examples/roberta/wsc/wsc_utils.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
from functools import lru_cache
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def convert_sentence_to_json(sentence):
    """Convert a marked-up WSC sentence into a SuperGLUE-style JSON dict.

    The query (candidate antecedent) is delimited by underscores and the
    pronoun by square brackets, e.g. ``"The _trophy_ ... the [it] ..."``.
    All markers are stripped from the returned ``text``; span indices are
    counted in space-separated tokens.
    """

    def token_offset(text_before_marker):
        # number of space-separated tokens preceding the marked word
        return len(text_before_marker.rstrip().split(" "))

    if "_" not in sentence:
        query_text, query_pos = None, None
    else:
        before, remainder = sentence.split("_", 1)
        query_text, remainder = remainder.split("_", 1)
        query_pos = token_offset(before)

    before, remainder = sentence.split("[", 1)
    pronoun_text, remainder = remainder.split("]", 1)
    pronoun_pos = token_offset(before)

    clean_text = sentence.replace("_", "").replace("[", "").replace("]", "")

    target = {
        "span1_index": query_pos,
        "span1_text": query_text,
        "span2_index": pronoun_pos,
        "span2_text": pronoun_text,
    }
    return {"idx": 0, "text": clean_text, "target": target}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def extended_noun_chunks(sentence):
    """Return spaCy's noun chunks augmented with maximal NOUN/PROPN runs.

    ``sentence`` is a spaCy Doc/Span; the result is a list of spans sorted
    by their (start, end) token offsets.
    """
    spans = {(chunk.start, chunk.end) for chunk in sentence.noun_chunks}

    run_start, run_type = 0, "NONE"
    for pos, tok in enumerate(sentence):
        kind = tok.pos_ if tok.pos_ in {"NOUN", "PROPN"} else "NONE"
        if kind == run_type:
            continue
        if run_type != "NONE":
            # a noun run just ended right before this position
            spans.add((run_start, pos))
        if kind != "NONE":
            run_start = pos
        run_type = kind
    if run_type != "NONE":
        # the sentence ended while still inside a noun run
        spans.add((run_start, len(sentence)))

    return [sentence[s:e] for (s, e) in sorted(spans)]
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def find_token(sentence, start_pos):
    """Return the first token whose character offset ``idx`` equals
    ``start_pos``, or None when no token starts there."""
    return next((tok for tok in sentence if tok.idx == start_pos), None)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def find_span(sentence, search_text, start=0):
    """Find the first token-aligned span of *sentence* matching *search_text*.

    Matching is case-insensitive and begins at token index ``start``.
    Returns a spaCy span whose surface characters exactly cover the search
    text, or None when no token-aligned match exists.
    """
    needle = search_text.lower()
    target_len = len(needle)
    for anchor in sentence[start:]:
        tail = sentence[anchor.i:]
        if not tail.text.lower().startswith(needle):
            continue
        # extend token by token until the covered characters equal the needle
        anchor_idx = anchor.idx
        for last in tail:
            if last.idx + len(last.text) - anchor_idx == target_len:
                return sentence[anchor.i : last.i + 1]
    return None
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@lru_cache(maxsize=1)
def get_detokenizer():
    """Lazily construct (and cache) the English Moses detokenizer."""
    from sacremoses import MosesDetokenizer

    return MosesDetokenizer(lang="en")
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@lru_cache(maxsize=1)
def get_spacy_nlp():
    """Lazily load (and cache) the large English spaCy pipeline."""
    import en_core_web_lg

    return en_core_web_lg.load()
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def jsonl_iterator(input_fname, positive_only=False, ngram_order=3, eval=False):
    """Yield preprocessed WSC examples from a SuperGLUE-style .jsonl file.

    Each line is parsed, the marked pronoun located and re-detokenized, and
    the resulting sentence re-parsed with spaCy.  With ``eval=True`` yields
    ``(marked_sentence_str, label)``; otherwise yields
    ``(spacy_doc, pronoun_span, query_text, label)``.

    ``ngram_order`` is accepted for API compatibility but is not used in
    this function body.
    """
    detok = get_detokenizer()
    nlp = get_spacy_nlp()

    with open(input_fname) as fin:
        for line in fin:
            sample = json.loads(line.strip())

            if positive_only and "label" in sample and not sample["label"]:
                # only consider examples where the query is correct
                continue

            target = sample["target"]

            # clean up the query
            query = target["span1_text"]
            if query is not None:
                if "\n" in query:
                    # malformed multi-line query: skip the whole example
                    continue
                if query.endswith(".") or query.endswith(","):
                    query = query[:-1]

            # split tokens
            tokens = sample["text"].split(" ")

            def strip_pronoun(x):
                # drop trailing punctuation/quotes for robust comparison
                return x.rstrip('.,"')

            # find the pronoun
            pronoun_idx = target["span2_index"]
            pronoun = strip_pronoun(target["span2_text"])
            if strip_pronoun(tokens[pronoun_idx]) != pronoun:
                # hack: sometimes the index is misaligned
                if strip_pronoun(tokens[pronoun_idx + 1]) == pronoun:
                    pronoun_idx += 1
                else:
                    raise Exception("Misaligned pronoun!")
            assert strip_pronoun(tokens[pronoun_idx]) == pronoun

            # split tokens before and after the pronoun
            before = tokens[:pronoun_idx]
            after = tokens[pronoun_idx + 1 :]

            # the GPT BPE attaches leading spaces to tokens, so we keep track
            # of whether we need spaces before or after the pronoun
            leading_space = " " if pronoun_idx > 0 else ""
            trailing_space = " " if len(after) > 0 else ""

            # detokenize
            before = detok.detokenize(before, return_str=True)
            pronoun = detok.detokenize([pronoun], return_str=True)
            after = detok.detokenize(after, return_str=True)

            # hack: when the pronoun ends in a period (or comma), move the
            # punctuation to the "after" part
            if pronoun.endswith(".") or pronoun.endswith(","):
                after = pronoun[-1] + trailing_space + after
                pronoun = pronoun[:-1]

            # hack: when the "after" part begins with a comma or period, remove
            # the trailing space
            if after.startswith(".") or after.startswith(","):
                trailing_space = ""

            # parse sentence with spacy
            sentence = nlp(before + leading_space + pronoun + trailing_space + after)

            # find pronoun span
            start = len(before + leading_space)
            first_pronoun_tok = find_token(sentence, start_pos=start)
            pronoun_span = find_span(sentence, pronoun, start=first_pronoun_tok.i)
            assert pronoun_span.text == pronoun

            if eval:
                # convert to format where pronoun is surrounded by "[]" and
                # query is surrounded by "_"
                query_span = find_span(sentence, query)
                query_with_ws = "_{}_{}".format(
                    query_span.text,
                    (" " if query_span.text_with_ws.endswith(" ") else ""),
                )
                pronoun_with_ws = "[{}]{}".format(
                    pronoun_span.text,
                    (" " if pronoun_span.text_with_ws.endswith(" ") else ""),
                )
                # splice markers back in document order
                if query_span.start < pronoun_span.start:
                    first = (query_span, query_with_ws)
                    second = (pronoun_span, pronoun_with_ws)
                else:
                    first = (pronoun_span, pronoun_with_ws)
                    second = (query_span, query_with_ws)
                sentence = (
                    sentence[: first[0].start].text_with_ws
                    + first[1]
                    + sentence[first[0].end : second[0].start].text_with_ws
                    + second[1]
                    + sentence[second[0].end :].text
                )
                yield sentence, sample.get("label", None)
            else:
                yield sentence, pronoun_span, query, sample.get("label", None)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def winogrande_jsonl_iterator(input_fname, eval=False):
    """Yield ``(sentence, pronoun_span, query, candidate)`` tuples from a
    WinoGrande .jsonl file.

    The pronoun placeholder is the single ``_`` character in each sentence;
    its span is the character interval covering that underscore.  In eval
    mode the two options are yielded in file order (no gold answer is
    consulted); otherwise ``query`` is the correct option and ``candidate``
    the incorrect one.
    """
    with open(input_fname) as handle:
        for raw_line in handle:
            example = json.loads(raw_line.strip())
            text = example["sentence"]
            first_option = example["option1"]
            second_option = example["option2"]

            blank = text.index("_")
            placeholder_span = (blank, blank + 1)

            if eval:
                correct, other = first_option, second_option
            elif example["answer"] == "1":
                correct, other = first_option, second_option
            else:
                correct, other = second_option, first_option

            yield text, placeholder_span, correct, other
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def filter_noun_chunks(
    chunks, exclude_pronouns=False, exclude_query=None, exact_match=False
):
    """Filter candidate noun chunks for WSC-style candidate generation.

    With ``exclude_pronouns`` chunks that are (entirely) pronouns are
    dropped.  With ``exclude_query`` set, a chunk is dropped when its
    lowercased text contains / is contained in the lowercased query
    (``exact_match=False``) or equals it exactly (``exact_match=True``).
    """
    if exclude_pronouns:
        chunks = [
            span
            for span in chunks
            if span.lemma_ != "-PRON-" and not all(t.pos_ == "PRON" for t in span)
        ]

    if exclude_query is not None:
        blocked = [exclude_query.lower()]

        def overlaps_query(lowered):
            for excl in blocked:
                exact = lowered == excl
                fuzzy = not exact_match and (lowered in excl or excl in lowered)
                if exact or fuzzy:
                    return True
            return False

        chunks = [c for c in chunks if not overlaps_query(c.text.lower())]

    return chunks
|
fairseq-0.10.2/examples/scaling_nmt/README.md
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Scaling Neural Machine Translation (Ott et al., 2018)
|
| 2 |
+
|
| 3 |
+
This page includes instructions for reproducing results from the paper [Scaling Neural Machine Translation (Ott et al., 2018)](https://arxiv.org/abs/1806.00187).
|
| 4 |
+
|
| 5 |
+
## Pre-trained models
|
| 6 |
+
|
| 7 |
+
Model | Description | Dataset | Download
|
| 8 |
+
---|---|---|---
|
| 9 |
+
`transformer.wmt14.en-fr` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT14 English-French](http://statmt.org/wmt14/translation-task.html#Download) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt14.en-fr.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt14.en-fr.joined-dict.newstest2014.tar.bz2)
|
| 10 |
+
`transformer.wmt16.en-de` | Transformer <br> ([Ott et al., 2018](https://arxiv.org/abs/1806.00187)) | [WMT16 English-German](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8) | model: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/models/wmt16.en-de.joined-dict.transformer.tar.bz2) <br> newstest2014: <br> [download (.tar.bz2)](https://dl.fbaipublicfiles.com/fairseq/data/wmt16.en-de.joined-dict.newstest2014.tar.bz2)
|
| 11 |
+
|
| 12 |
+
## Training a new model on WMT'16 En-De
|
| 13 |
+
|
| 14 |
+
First download the [preprocessed WMT'16 En-De data provided by Google](https://drive.google.com/uc?export=download&id=0B_bZck-ksdkpM25jRUN2X2UxMm8).
|
| 15 |
+
|
| 16 |
+
Then:
|
| 17 |
+
|
| 18 |
+
##### 1. Extract the WMT'16 En-De data
|
| 19 |
+
```bash
|
| 20 |
+
TEXT=wmt16_en_de_bpe32k
|
| 21 |
+
mkdir -p $TEXT
|
| 22 |
+
tar -xzvf wmt16_en_de.tar.gz -C $TEXT
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
##### 2. Preprocess the dataset with a joined dictionary
|
| 26 |
+
```bash
|
| 27 |
+
fairseq-preprocess \
|
| 28 |
+
--source-lang en --target-lang de \
|
| 29 |
+
--trainpref $TEXT/train.tok.clean.bpe.32000 \
|
| 30 |
+
--validpref $TEXT/newstest2013.tok.bpe.32000 \
|
| 31 |
+
--testpref $TEXT/newstest2014.tok.bpe.32000 \
|
| 32 |
+
--destdir data-bin/wmt16_en_de_bpe32k \
|
| 33 |
+
--nwordssrc 32768 --nwordstgt 32768 \
|
| 34 |
+
--joined-dictionary \
|
| 35 |
+
--workers 20
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
##### 3. Train a model
|
| 39 |
+
```bash
|
| 40 |
+
fairseq-train \
|
| 41 |
+
data-bin/wmt16_en_de_bpe32k \
|
| 42 |
+
--arch transformer_vaswani_wmt_en_de_big --share-all-embeddings \
|
| 43 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
|
| 44 |
+
--lr 0.0005 --lr-scheduler inverse_sqrt --warmup-updates 4000 --warmup-init-lr 1e-07 \
|
| 45 |
+
--dropout 0.3 --weight-decay 0.0 \
|
| 46 |
+
--criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
|
| 47 |
+
--max-tokens 3584 \
|
| 48 |
+
--fp16
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
Note that the `--fp16` flag requires you have CUDA 9.1 or greater and a Volta GPU or newer.
|
| 52 |
+
|
| 53 |
+
***IMPORTANT:*** You will get better performance by training with big batches and
|
| 54 |
+
increasing the learning rate. If you want to train the above model with big batches
|
| 55 |
+
(assuming your machine has 8 GPUs):
|
| 56 |
+
- add `--update-freq 16` to simulate training on 8x16=128 GPUs
|
| 57 |
+
- increase the learning rate; 0.001 works well for big batches
|
| 58 |
+
|
| 59 |
+
##### 4. Evaluate
|
| 60 |
+
|
| 61 |
+
Now we can evaluate our trained model.
|
| 62 |
+
|
| 63 |
+
Note that the original [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
|
| 64 |
+
paper used a couple tricks to achieve better BLEU scores. We use these same tricks in
|
| 65 |
+
the Scaling NMT paper, so it's important to apply them when reproducing our results.
|
| 66 |
+
|
| 67 |
+
First, use the [average_checkpoints.py](/scripts/average_checkpoints.py) script to
|
| 68 |
+
average the last few checkpoints. Averaging the last 5-10 checkpoints is usually
|
| 69 |
+
good, but you may need to adjust this depending on how long you've trained:
|
| 70 |
+
```bash
|
| 71 |
+
python scripts/average_checkpoints.py \
|
| 72 |
+
--inputs /path/to/checkpoints \
|
| 73 |
+
--num-epoch-checkpoints 10 \
|
| 74 |
+
--output checkpoint.avg10.pt
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
Next, generate translations using a beam width of 4 and length penalty of 0.6:
|
| 78 |
+
```bash
|
| 79 |
+
fairseq-generate \
|
| 80 |
+
data-bin/wmt16_en_de_bpe32k \
|
| 81 |
+
--path checkpoint.avg10.pt \
|
| 82 |
+
--beam 4 --lenpen 0.6 --remove-bpe > gen.out
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
Finally, we apply the ["compound splitting" script](/scripts/compound_split_bleu.sh) to
|
| 86 |
+
add spaces around dashes. For example "Café-Liebhaber" would become three tokens:
|
| 87 |
+
"Café - Liebhaber". This typically results in larger BLEU scores, but it is not
|
| 88 |
+
appropriate to compare these inflated scores to work which does not include this trick.
|
| 89 |
+
This trick was used in the [original AIAYN code](https://github.com/tensorflow/tensor2tensor/blob/fc9335c0203685cbbfe2b30c92db4352d8f60779/tensor2tensor/utils/get_ende_bleu.sh),
|
| 90 |
+
so we used it in the Scaling NMT paper as well. That said, it's strongly advised to
|
| 91 |
+
report [sacrebleu](https://github.com/mjpost/sacrebleu) scores instead.
|
| 92 |
+
|
| 93 |
+
To compute "compound split" tokenized BLEU (not recommended!):
|
| 94 |
+
```bash
|
| 95 |
+
bash scripts/compound_split_bleu.sh gen.out
|
| 96 |
+
# BLEU4 = 29.29, 60.3/35.0/22.8/15.3 (BP=1.000, ratio=1.004, syslen=64763, reflen=64496)
|
| 97 |
+
```
|
| 98 |
+
|
| 99 |
+
To compute detokenized BLEU with sacrebleu (preferred):
|
| 100 |
+
```bash
|
| 101 |
+
bash scripts/sacrebleu.sh wmt14/full en de gen.out
|
| 102 |
+
# BLEU+case.mixed+lang.en-de+numrefs.1+smooth.exp+test.wmt14/full+tok.13a+version.1.4.3 = 28.6 59.3/34.3/22.1/14.9 (BP = 1.000 ratio = 1.016 hyp_len = 63666 ref_len = 62688)
|
| 103 |
+
```
|
| 104 |
+
|
| 105 |
+
## Citation
|
| 106 |
+
|
| 107 |
+
```bibtex
|
| 108 |
+
@inproceedings{ott2018scaling,
|
| 109 |
+
title = {Scaling Neural Machine Translation},
|
| 110 |
+
author = {Ott, Myle and Edunov, Sergey and Grangier, David and Auli, Michael},
|
| 111 |
+
booktitle = {Proceedings of the Third Conference on Machine Translation (WMT)},
|
| 112 |
+
year = 2018,
|
| 113 |
+
}
|
| 114 |
+
```
|
fairseq-0.10.2/examples/simultaneous_translation/criterions/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os


# Auto-import every sibling criterion module so that each module's
# @register_criterion decorator runs and the criterion becomes available
# to fairseq by name.
_PACKAGE = "examples.simultaneous_translation.criterions"
for _fname in os.listdir(os.path.dirname(__file__)):
    if not _fname.endswith(".py") or _fname.startswith("_"):
        continue
    _module_name = _fname[: _fname.find(".py")]
    importlib.import_module(_PACKAGE + "." + _module_name)
|
fairseq-0.10.2/examples/simultaneous_translation/criterions/label_smoothed_cross_entropy_latency_augmented.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from examples.simultaneous_translation.utils.latency import LatencyTraining
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import (
    LabelSmoothedCrossEntropyCriterion,
)


@register_criterion("latency_augmented_label_smoothed_cross_entropy")
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    LabelSmoothedCrossEntropyCriterion
):
    """Label-smoothed cross entropy augmented with a latency penalty for
    simultaneous translation (weighted average and variance terms computed
    from the model's expected source/target alignment)."""

    def __init__(self, args, task):
        super().__init__(args, task)
        self.eps = args.label_smoothing
        self.latency_weight_avg = args.latency_weight_avg
        self.latency_weight_avg_type = args.latency_weight_avg_type
        self.latency_weight_var = args.latency_weight_var
        self.latency_weight_var_type = args.latency_weight_var_type
        self.mass_preservation = args.mass_preservation
        self.average_method = args.average_method
        self.latency_train = LatencyTraining(
            self.latency_weight_avg,
            self.latency_weight_var,
            self.latency_weight_avg_type,
            self.latency_weight_var_type,
            self.mass_preservation,
            self.average_method,
        )

    @staticmethod
    def add_args(parser):
        """Add criterion-specific arguments to the parser."""
        # fix: this docstring was originally a stray string statement placed
        # after the super() call, where it documented nothing
        super(
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
            LatencyAugmentedLabelSmoothedCrossEntropyCriterion,
        ).add_args(parser)
        # fmt: off
        parser.add_argument("--latency-weight-avg", default=0., type=float, metavar='D',
                            help="Average loss weight")
        parser.add_argument("--latency-weight-var", default=0., type=float, metavar='D',
                            help="Variance loss weight")
        parser.add_argument("--latency-weight-avg-type", default="differentiable_average_lagging",
                            help="Statistics for Average loss type")
        parser.add_argument("--latency-weight-var-type", default="variance_delay",
                            help="Statistics for variance loss type")
        parser.add_argument("--average-method", default="weighted_average",
                            help="Average loss type")
        # fmt: on

    def compute_loss(self, model, net_output, sample, reduce=True):
        """Return ``(loss, nll_loss)`` with the latency penalty folded into
        ``loss``; ``nll_loss`` is left untouched for logging/ppl."""
        # Compute cross entropy loss first
        loss, nll_loss = super().compute_loss(model, net_output, sample, reduce)

        # Obtain the expected alignment
        attn_list = [item["alpha"] for item in net_output[-1]["attn_list"]]

        target_padding_mask = model.get_targets(sample, net_output).eq(self.padding_idx)

        source_padding_mask = net_output[-1].get("encoder_padding_mask", None)

        # Get latency loss
        latency_loss = self.latency_train.loss(
            attn_list, source_padding_mask, target_padding_mask
        )

        loss += latency_loss

        return loss, nll_loss
|
fairseq-0.10.2/examples/simultaneous_translation/eval/agents/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import importlib
import os

from fairseq import registry


# Registry keyed on --agent-type; agent modules below register themselves
# into it via the returned register_agent decorator.
build_agent, register_agent, MONOTONIC_AGENT, _ = registry.setup_registry(
    "--agent-type"
)


DEFAULT_EOS = "</s>"  # sentinel emitted by an agent when a sentence is done
GET = 0  # action key: request more source tokens
SEND = 1  # action key: emit target tokens

# Auto-import every sibling agent module so its @register_agent decorator runs.
# NOTE(review): this imports via the bare package name "agents." (the sibling
# criterions/__init__ uses a fully-qualified dotted path) — confirm sys.path
# is set up so "agents" is importable when this module loads.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("agents." + module)
|
fairseq-0.10.2/examples/simultaneous_translation/eval/agents/agent.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import time
from functools import partial
from multiprocessing.pool import ThreadPool as Pool

from . import DEFAULT_EOS, GET, SEND


class Agent(object):
    "an agent needs to follow this pattern"

    def __init__(self, *args, **kwargs):
        pass

    def init_states(self, *args, **kwargs):
        # return the initial decoding state for one sentence
        raise NotImplementedError

    def update_states(self, states, new_state):
        # fold newly received source tokens into the decoding state
        raise NotImplementedError

    def finish_eval(self, states, new_state):
        raise NotImplementedError

    def policy(self, state):
        # return the next action dict: {"key": GET|SEND, "value": ...}
        raise NotImplementedError

    def reset(self):
        raise NotImplementedError

    def decode(self, session, low=0, high=100000, num_thread=10):
        """Decode sentences ``low``..``high`` (inclusive) from *session*.

        Runs sequentially when ``num_thread <= 1``; otherwise fans out over
        a thread pool of ``num_thread`` workers.
        """
        corpus_info = session.corpus_info()
        high = min(corpus_info["num_sentences"] - 1, high)
        if low >= high:
            return

        t0 = time.time()
        if num_thread > 1:
            # fix: was Pool(10), which silently ignored num_thread
            with Pool(num_thread) as p:
                p.map(
                    partial(self._decode_one, session),
                    range(low, high + 1),
                )
        else:
            for sent_id in range(low, high + 1):
                self._decode_one(session, sent_id)

        print(f"Finished {low} to {high} in {time.time() - t0}s")

    def _decode_one(self, session, sent_id):
        """Run the read/write loop for one sentence until EOS is emitted."""
        action = {}
        self.reset()
        states = self.init_states()
        while action.get("value", None) != DEFAULT_EOS:
            # take an action
            action = self.policy(states)

            if action["key"] == GET:
                # request more source tokens and fold them into our state
                new_states = session.get_src(sent_id, action["value"])
                states = self.update_states(states, new_states)

            elif action["key"] == SEND:
                # emit a target hypothesis chunk
                session.send_hypo(sent_id, action["value"])
        print(" ".join(states["tokens"]["tgt"]))
|
fairseq-0.10.2/examples/simultaneous_translation/eval/agents/simul_trans_agent.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
from fairseq import checkpoint_utils, tasks, utils
|
| 10 |
+
|
| 11 |
+
from . import DEFAULT_EOS, GET, SEND
|
| 12 |
+
from .agent import Agent
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SimulTransAgent(Agent):
    """Simultaneous-translation agent backed by a pretrained fairseq model.

    Loads a checkpoint, builds subword splitters, and implements the
    read/write policy loop: the model decides at each step whether to READ
    another source segment or WRITE (predict) a target token.
    """

    def __init__(self, args):
        # Load Model
        self.load_model(args)

        # Build word splitter (subword tokenizers for src/tgt text).
        self.build_word_splitter(args)

        # Maximum number of target tokens before forcing sentence finish.
        self.max_len = args.max_len

        self.eos = DEFAULT_EOS

    @staticmethod
    def add_args(parser):
        """Add agent-specific command-line arguments to ``parser``."""
        # fmt: off
        parser.add_argument('--model-path', type=str, required=True,
                            help='path to your pretrained model.')
        parser.add_argument("--data-bin", type=str, required=True,
                            help="Path of data binary")
        parser.add_argument("--user-dir", type=str, default="example/simultaneous_translation",
                            help="User directory for simultaneous translation")
        parser.add_argument("--src-splitter-type", type=str, default=None,
                            help="Subword splitter type for source text")
        parser.add_argument("--tgt-splitter-type", type=str, default=None,
                            help="Subword splitter type for target text")
        parser.add_argument("--src-splitter-path", type=str, default=None,
                            help="Subword splitter model path for source text")
        parser.add_argument("--tgt-splitter-path", type=str, default=None,
                            help="Subword splitter model path for target text")
        parser.add_argument("--max-len", type=int, default=150,
                            help="Maximum length difference between source and target prediction")
        parser.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
                            help='A dictionary used to override model args at generation '
                                 'that were used during model training')
        # fmt: on
        return parser

    def load_dictionary(self, task):
        """Populate self.dict from the task; implemented by subclasses."""
        raise NotImplementedError

    def load_model(self, args):
        """Load the checkpoint at args.model_path and build the model/task.

        NOTE(review): args.user_dir from the command line is overwritten here
        with the package root two levels up — confirm this is intentional.
        """
        args.user_dir = os.path.join(os.path.dirname(__file__), "..", "..")
        utils.import_user_module(args)
        filename = args.model_path
        if not os.path.exists(filename):
            raise IOError("Model file not found: {}".format(filename))

        state = checkpoint_utils.load_checkpoint_to_cpu(
            filename, json.loads(args.model_overrides)
        )

        # Re-create the training task, pointing it at the eval data binary.
        saved_args = state["args"]
        saved_args.data = args.data_bin

        task = tasks.setup_task(saved_args)

        # build model for ensemble
        self.model = task.build_model(saved_args)
        self.model.load_state_dict(state["model"], strict=True)

        # Set dictionary
        self.load_dictionary(task)

    def init_states(self):
        """Return a fresh per-sentence decoding state.

        "steps" counts how many tokens have been consumed/emitted per side;
        "segments" holds full (detokenized) words, "tokens" holds subwords.
        """
        return {
            "indices": {"src": [], "tgt": []},
            "tokens": {"src": [], "tgt": []},
            "segments": {"src": [], "tgt": []},
            "steps": {"src": 0, "tgt": 0},
            "finished": False,
            "finish_read": False,
            "model_states": {},
        }

    def update_states(self, states, new_state):
        raise NotImplementedError

    def policy(self, states):
        """Read/write policy loop: query the model until an action is made.

        Returns an action dict for the server, looping internally while the
        chosen sub-action yields None (e.g. a buffered token was consumed or
        only a partial subword was predicted).
        """
        action = None

        while action is None:
            if states["finished"]:
                # Finish the hypothesis by sending EOS to the server.
                return self.finish_action()

            # Model makes a decision given the current states
            # (0 is interpreted as READ below, anything else as WRITE).
            decision = self.model.decision_from_states(states)

            if decision == 0 and not self.finish_read(states):
                # READ
                action = self.read_action(states)
            else:
                # WRITE
                action = self.write_action(states)

            # None means we make a decision again without sending the server
            # anything.  This happens when a buffered token was read,
            # or a subword (not yet a full word) was predicted.
        return action

    def finish_read(self, states):
        """Return True once the full source sentence has been consumed."""
        raise NotImplementedError

    def write_action(self, states):
        """Predict one target token; emit a full word to the server if ready.

        Returns a SEND action when a complete detokenized word is available,
        otherwise None (caller loops and decides again).
        """
        token, index = self.model.predict_from_states(states)

        if (
            index == self.dict["tgt"].eos()
            or len(states["tokens"]["tgt"]) > self.max_len
        ):
            # Finish this sentence if EOS is predicted or max length exceeded.
            states["finished"] = True
            end_idx_last_full_word = self._target_length(states)

        else:
            states["tokens"]["tgt"] += [token]
            end_idx_last_full_word = self.word_splitter["tgt"].end_idx_last_full_word(
                states["tokens"]["tgt"]
            )
            self._append_indices(states, [index], "tgt")

        if end_idx_last_full_word > states["steps"]["tgt"]:
            # Only send detokenized full words to the server.
            word = self.word_splitter["tgt"].merge(
                states["tokens"]["tgt"][states["steps"]["tgt"] : end_idx_last_full_word]
            )
            states["steps"]["tgt"] = end_idx_last_full_word
            states["segments"]["tgt"] += [word]

            return {"key": SEND, "value": word}
        else:
            return None

    def read_action(self, states):
        """Request the next source segment from the server."""
        return {"key": GET, "value": None}

    def finish_action(self):
        """Send EOS to the server to terminate the hypothesis."""
        return {"key": SEND, "value": DEFAULT_EOS}

    def reset(self):
        pass

    def finish_eval(self, states, new_state):
        """True when the server sent nothing and no source was ever read."""
        if len(new_state) == 0 and len(states["indices"]["src"]) == 0:
            return True
        return False

    def _append_indices(self, states, new_indices, key):
        # Extend the dictionary-index list for side `key` ("src" or "tgt").
        states["indices"][key] += new_indices

    def _target_length(self, states):
        # Number of target subword tokens emitted so far.
        return len(states["tokens"]["tgt"])
|
fairseq-0.10.2/examples/simultaneous_translation/eval/agents/simul_trans_text_agent.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import DEFAULT_EOS, GET, register_agent
|
| 7 |
+
from .simul_trans_agent import SimulTransAgent
|
| 8 |
+
from .word_splitter import SPLITTER_DICT
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@register_agent("simul_trans_text")
class SimulTransTextAgent(SimulTransAgent):
    """Text-mode simultaneous-translation agent.

    Reads source text word by word from the server, splits each word into
    subwords, and tracks both remote (full-word) and local (subword)
    read progress.
    """

    def build_word_splitter(self, args):
        """Build the src/tgt subword splitters from command-line args."""
        self.word_splitter = {}

        # FIX: `getattr(args, f"src_splitter_path")` was an f-string with no
        # placeholder passed to getattr on a literal name — plain attribute
        # access is equivalent and idiomatic.
        self.word_splitter["src"] = SPLITTER_DICT[args.src_splitter_type](
            args.src_splitter_path
        )
        self.word_splitter["tgt"] = SPLITTER_DICT[args.tgt_splitter_type](
            args.tgt_splitter_path
        )

    def load_dictionary(self, task):
        """Take the source/target dictionaries from the fairseq task."""
        self.dict = {}
        self.dict["tgt"] = task.target_dictionary
        self.dict["src"] = task.source_dictionary

    def update_states(self, states, new_state):
        """Fold a newly received source word into the decoding state."""
        if states["finish_read"]:
            return states

        new_word = new_state["segment"]

        # Split words and index the tokens.
        if new_word != DEFAULT_EOS:
            tokens = self.word_splitter["src"].split(new_word)
            # Get indices from dictionary.
            # You can change to your own dictionary.
            indices = (
                self.dict["src"]
                .encode_line(
                    tokens,
                    line_tokenizer=lambda x: x,
                    add_if_not_exist=False,
                    append_eos=False,
                )
                .tolist()
            )
        else:
            tokens = [new_word]
            indices = [self.dict["src"].eos()]
            states["finish_read"] = True

        # Update states.
        states["segments"]["src"] += [new_word]
        states["tokens"]["src"] += tokens
        self._append_indices(states, indices, "src")

        return states

    def read_action(self, states):
        """Advance the source cursor; ask the server only when no token is buffered."""
        # Increase source step by one.
        states["steps"]["src"] += 1

        # At least one word must be read.
        if len(states["tokens"]["src"]) == 0:
            return {"key": GET, "value": None}

        # Only request a new word if there are no buffered tokens.
        if len(states["tokens"]["src"]) <= states["steps"]["src"]:
            return {"key": GET, "value": None}

        return None

    def finish_read(self, states):
        """True once the source is fully read, remotely and locally.

        The first condition means all segments (full words) have been read
        from the server; the second means all tokens (subwords) have been
        consumed locally.
        """
        return (
            states["finish_read"]
            and len(states["tokens"]["src"]) == states["steps"]["src"]
        )
|
fairseq-0.10.2/examples/simultaneous_translation/eval/scorers/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import importlib
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
from fairseq import registry
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Registry for evaluation scorers.  Scorer classes register themselves with
# @register_scorer(name); build_scorer instantiates the class selected by the
# --scorer-type command-line flag.  SCORER_REGISTRIES is the underlying
# name -> class mapping returned by setup_registry.
(build_scorer, register_scorer, SCORER_REGISTRIES, _) = registry.setup_registry(
    "--scorer-type"
)

# Auto-import every sibling .py module (except private ones) so that their
# @register_scorer decorators run and populate the registry at import time.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        module = file[: file.find(".py")]
        importlib.import_module("scorers." + module)
|
fairseq-0.10.2/examples/simultaneous_translation/eval/scorers/scorer.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import os
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
|
| 10 |
+
from examples.simultaneous_translation.eval.eval_latency import LatencyScorer
|
| 11 |
+
from vizseq.scorers.bleu import BLEUScorer
|
| 12 |
+
from vizseq.scorers.meteor import METEORScorer
|
| 13 |
+
from vizseq.scorers.ter import TERScorer
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
DEFAULT_EOS = "</s>"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SimulScorer(object):
    """Base scorer for simultaneous translation evaluation.

    Collects per-sentence hypothesis tokens together with the source step at
    which each was produced, then scores quality (BLEU/TER/METEOR via vizseq)
    and latency (AL/AP/DAL via LatencyScorer).  Subclasses supply the data
    loading (``send_src``, ``src_lengths``).
    """

    def __init__(self, args):
        self.tokenizer = args.tokenizer
        self.output_dir = args.output
        if args.output is not None:
            self.output_files = {
                "text": os.path.join(args.output, "text"),
                "delay": os.path.join(args.output, "delay"),
                "scores": os.path.join(args.output, "scores"),
            }
        else:
            self.output_files = None
        self.eos = DEFAULT_EOS
        self.data = {"tgt": []}
        self.reset()

    def get_info(self):
        """Return corpus metadata exposed to clients."""
        return {"num_sentences": len(self)}

    @staticmethod
    def add_args(parser):
        """Add scorer-specific command-line arguments to ``parser``."""
        # fmt: off
        parser.add_argument('--src-file', type=str, required=True,
                            help='Source input file')
        parser.add_argument('--tgt-file', type=str, required=True,
                            help='Target reference file')
        parser.add_argument('--tokenizer', default="13a", choices=["none", "13a"],
                            help='Tokenizer used for sacrebleu')
        parser.add_argument('--output', type=str, default=None,
                            help='Path for output directory')
        # fmt: on

    def send_src(self, sent_id, *args):
        """Serve the next source segment for ``sent_id``; subclass hook."""
        raise NotImplementedError

    def recv_hyp(self, sent_id, list_of_tokens):
        """Record hypothesis tokens, each tagged with the current source step."""
        for token in list_of_tokens:
            self.translations[sent_id].append((token, self.steps[sent_id]))

    def reset(self):
        """Clear all per-sentence progress and collected hypotheses."""
        self.steps = defaultdict(int)
        self.translations = defaultdict(list)

    def src_lengths(self):
        """Return per-sentence source lengths; subclass hook."""
        raise NotImplementedError

    def score(self):
        """Compute quality and latency scores over all received hypotheses.

        Returns a dict with BLEU/TER/METEOR and DAL/AL/AP entries; optionally
        writes translations, delays and scores to the output directory.
        """
        translations = []
        delays = []
        for i in range(1 + max(self.translations.keys())):
            # [:-1] drops the final token — presumably the EOS marker is
            # excluded from the scored text but kept for delay computation.
            translations += [" ".join(t[0] for t in self.translations[i][:-1])]
            delays += [[t[1] for t in self.translations[i]]]

        bleu_score = BLEUScorer(
            sent_level=False,
            corpus_level=True,
            extra_args={"bleu_tokenizer": self.tokenizer},
        ).score(translations, [self.data["tgt"]])

        ter_score = TERScorer(sent_level=False, corpus_level=True).score(
            translations, [self.data["tgt"]]
        )
        meteor_score = METEORScorer(sent_level=False, corpus_level=True).score(
            translations, [self.data["tgt"]]
        )

        latency_score = LatencyScorer().score(
            [
                {"src_len": src_len, "delays": delay}
                for src_len, delay in zip(self.src_lengths(), delays)
            ],
            start_from_zero=False,
        )

        scores = {
            "BLEU": bleu_score[0],
            "TER": ter_score[0],
            "METEOR": meteor_score[0],
            "DAL": latency_score["differentiable_average_lagging"],
            "AL": latency_score["average_lagging"],
            "AP": latency_score["average_proportion"],
        }

        if self.output_files is not None:
            # Writing results is best-effort: failures are reported but do
            # not abort scoring.  FIX: was `except BaseException`, which also
            # swallowed KeyboardInterrupt/SystemExit; Exception is the
            # correct bound for best-effort error handling.
            try:
                os.makedirs(self.output_dir, exist_ok=True)
                self.write_results_to_file(translations, delays, scores)
            except Exception as be:
                print(f"Failed to write results to {self.output_dir}.")
                print(be)
                print("Skip writing predictions")

        return scores

    def write_results_to_file(self, translations, delays, scores):
        """Write translations, per-sentence delays (JSON lines) and scores."""
        if self.output_files["text"] is not None:
            with open(self.output_files["text"], "w") as f:
                for line in translations:
                    f.write(line + "\n")

        if self.output_files["delay"] is not None:
            with open(self.output_files["delay"], "w") as f:
                for i, delay in enumerate(delays):
                    f.write(
                        json.dumps({"src_len": self.src_lengths()[i], "delays": delay})
                        + "\n"
                    )

        with open(self.output_files["scores"], "w") as f:
            for key, value in scores.items():
                f.write(f"{key}, {value}\n")

    @classmethod
    def _load_text_file(cls, file, split=False):
        """Load a text file as stripped lines; optionally split into tokens."""
        with open(file) as f:
            if split:
                return [r.strip().split() for r in f]
            else:
                return [r.strip() for r in f]

    @classmethod
    def _load_text_from_json(cls, file):
        """Load utterance texts from an ESPnet-style JSON manifest."""
        list_to_return = []
        with open(file) as f:
            content = json.load(f)
            for item in content["utts"].values():
                list_to_return.append(item["output"]["text"].strip())
        return list_to_return

    @classmethod
    def _load_wav_info_from_json(cls, file):
        """Load audio path/length entries from a JSON manifest."""
        list_to_return = []
        with open(file) as f:
            content = json.load(f)
            for item in content["utts"].values():
                list_to_return.append(
                    {
                        "path": item["input"]["path"].strip(),
                        "length": item["input"]["length_ms"],
                    }
                )
        return list_to_return

    @classmethod
    def _load_wav_info_from_list(cls, file):
        """Load audio paths from a plain list file (one path per line)."""
        list_to_return = []
        with open(file) as f:
            for line in f:
                list_to_return.append(
                    {
                        "path": line.strip(),
                    }
                )
        return list_to_return

    def __len__(self):
        # Corpus size is defined by the number of reference sentences.
        return len(self.data["tgt"])
|
fairseq-0.10.2/examples/simultaneous_translation/eval/scorers/text_scorer.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import register_scorer
|
| 7 |
+
from .scorer import SimulScorer
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@register_scorer("text")
class SimulTextScorer(SimulScorer):
    """Text-mode scorer: serves source sentences one token at a time."""

    def __init__(self, args):
        super().__init__(args)
        self.data = {
            "src": self._load_text_file(args.src_file, split=True),
            "tgt": self._load_text_file(args.tgt_file, split=False),
        }

    def send_src(self, sent_id, *args):
        """Serve the next source token for ``sent_id``, or EOS when exhausted."""
        cursor = self.steps[sent_id]
        source_tokens = self.data["src"][sent_id]

        if cursor < len(source_tokens):
            # Serve the next real token and advance the cursor by one.
            segment = source_tokens[cursor]
            self.steps[sent_id] = cursor + 1
        else:
            # Source exhausted: serve EOS and park the cursor one past the
            # end so the EOS counts as a consumed step.
            segment = self.eos
            self.steps[sent_id] = len(source_tokens) + 1

        return {
            "sent_id": sent_id,
            "segment_id": cursor,
            "segment": segment,
        }

    def src_lengths(self):
        """Per-sentence source lengths, counting the trailing EOS segment."""
        return [len(tokens) + 1 for tokens in self.data["src"]]
|
fairseq-0.10.2/examples/speech_recognition/criterions/ASG_loss.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the MIT license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from examples.speech_recognition.data.replabels import pack_replabels
|
| 10 |
+
from fairseq import utils
|
| 11 |
+
from fairseq.criterions import FairseqCriterion, register_criterion
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@register_criterion("asg_loss")
class ASGCriterion(FairseqCriterion):
    """Auto Segmentation (ASG) criterion backed by wav2letter's ASGLoss.

    Optionally uses LinSeg (linear segmentation) targets for the first
    ``--linseg-updates`` training updates to initialize the ASG transitions.
    """

    @staticmethod
    def add_args(parser):
        """Add ASG-specific command-line arguments."""
        group = parser.add_argument_group("ASG Loss")
        group.add_argument(
            "--asg-transitions-init",
            help="initial diagonal value of transition matrix",
            type=float,
            default=0.0,
        )
        group.add_argument(
            "--max-replabel", help="maximum # of replabels", type=int, default=2
        )
        group.add_argument(
            "--linseg-updates",
            help="# of training updates to use LinSeg initialization",
            type=int,
            default=0,
        )
        group.add_argument(
            "--hide-linseg-messages",
            help="hide messages about LinSeg initialization",
            action="store_true",
        )

    def __init__(
        self,
        task,
        silence_token,
        asg_transitions_init,
        max_replabel,
        linseg_updates,
        hide_linseg_messages,
    ):
        # Imported lazily so the criterion module can be loaded without
        # wav2letter installed.
        from wav2letter.criterion import ASGLoss, CriterionScaleMode

        super().__init__(task)
        self.tgt_dict = task.target_dictionary
        self.eos = self.tgt_dict.eos()
        # Dictionary index of the silence token, or None when absent.
        self.silence = (
            self.tgt_dict.index(silence_token)
            if silence_token in self.tgt_dict
            else None
        )
        self.max_replabel = max_replabel

        num_labels = len(self.tgt_dict)
        self.asg = ASGLoss(num_labels, scale_mode=CriterionScaleMode.TARGET_SZ_SQRT)
        # Initialize the transition matrix with a scaled identity (diagonal
        # value from --asg-transitions-init); learned during training.
        self.asg.trans = torch.nn.Parameter(
            asg_transitions_init * torch.eye(num_labels), requires_grad=True
        )

        # Stored as a (non-trainable) Parameter so LinSeg progress is saved
        # and restored with the model checkpoint.
        self.linseg_progress = torch.nn.Parameter(
            torch.tensor([0], dtype=torch.int), requires_grad=False
        )
        self.linseg_maximum = linseg_updates
        # Tiny state machine for one-time log messages: start -> finish -> none.
        self.linseg_message_state = "none" if hide_linseg_messages else "start"

    @classmethod
    def build_criterion(cls, args, task):
        """Construct the criterion from parsed command-line args."""
        return cls(
            task,
            args.silence_token,
            args.asg_transitions_init,
            args.max_replabel,
            args.linseg_updates,
            args.hide_linseg_messages,
        )

    def linseg_step(self):
        """Advance the LinSeg counter; return True while LinSeg is active.

        Only counts during training; logs once when LinSeg starts and once
        when it finishes.
        """
        if not self.training:
            return False
        if self.linseg_progress.item() < self.linseg_maximum:
            if self.linseg_message_state == "start":
                print("| using LinSeg to initialize ASG")
                self.linseg_message_state = "finish"
            self.linseg_progress.add_(1)
            return True
        elif self.linseg_message_state == "finish":
            print("| finished LinSeg initialization")
            self.linseg_message_state = "none"
        return False

    def replace_eos_with_silence(self, tgt):
        """Replace a trailing EOS with silence (or drop it if redundant)."""
        if tgt[-1] != self.eos:
            return tgt
        elif self.silence is None or (len(tgt) > 1 and tgt[-2] == self.silence):
            # No silence token, or the sequence already ends in silence:
            # just drop the EOS.
            return tgt[:-1]
        else:
            return tgt[:-1] + [self.silence]

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size, which is used as the denominator for the gradient
        3) logging outputs to display while training
        """

        net_output = model(**sample["net_input"])
        # encoder_out is transposed to batch-first before feeding ASG.
        emissions = net_output["encoder_out"].transpose(0, 1).contiguous()
        B = emissions.size(0)
        T = emissions.size(1)
        device = emissions.device

        # NOTE(review): target/target_size are filled per-row below; rows are
        # only written up to len(tgt), the remainder stays uninitialized and
        # is presumably ignored by ASGLoss via target_size — confirm.
        target = torch.IntTensor(B, T)
        target_size = torch.IntTensor(B)
        using_linseg = self.linseg_step()

        for b in range(B):
            initial_target_size = sample["target_lengths"][b].item()
            if initial_target_size == 0:
                raise ValueError("target size cannot be zero")

            tgt = sample["target"][b, :initial_target_size].tolist()
            tgt = self.replace_eos_with_silence(tgt)
            tgt = pack_replabels(tgt, self.tgt_dict, self.max_replabel)
            # Targets longer than the emission length are truncated.
            tgt = tgt[:T]

            if using_linseg:
                # LinSeg: stretch the target linearly over all T frames.
                tgt = [tgt[t * len(tgt) // T] for t in range(T)]

            target[b][: len(tgt)] = torch.IntTensor(tgt)
            target_size[b] = len(tgt)

        loss = self.asg.forward(emissions, target.to(device), target_size.to(device))

        if reduce:
            loss = torch.sum(loss)

        # NOTE(review): relies on self.args being available on the criterion
        # (set by the FairseqCriterion base in this fairseq version) — confirm.
        sample_size = (
            sample["target"].size(0) if self.args.sentence_avg else sample["ntokens"]
        )
        logging_output = {
            "loss": utils.item(loss.data) if reduce else loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
        }
        return loss, sample_size, logging_output

    @staticmethod
    def aggregate_logging_outputs(logging_outputs):
        """Aggregate logging outputs from data parallel training."""
        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
        ntokens = sum(log.get("ntokens", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)
        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
        agg_output = {
            # Reported loss is averaged per sentence.
            "loss": loss_sum / nsentences,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        return agg_output