Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- fairseq-0.10.2/.github/ISSUE_TEMPLATE.md +3 -0
- fairseq-0.10.2/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
- fairseq-0.10.2/.github/ISSUE_TEMPLATE/documentation.md +15 -0
- fairseq-0.10.2/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
- fairseq-0.10.2/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
- fairseq-0.10.2/config/config.yaml +7 -0
- fairseq-0.10.2/config/config_eval_lm.yaml +7 -0
- fairseq-0.10.2/config/criterion/adaptive_loss.yaml +3 -0
- fairseq-0.10.2/config/criterion/cross_entropy.yaml +3 -0
- fairseq-0.10.2/config/lr_scheduler/cosine.yaml +7 -0
- fairseq-0.10.2/config/lr_scheduler/inverse_sqrt.yaml +3 -0
- fairseq-0.10.2/config/model/transformer_lm.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_baevski_gbw.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_baevski_wiki103.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_big.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_gbw.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_gpt.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_gpt2_big.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_gpt2_medium.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_gpt2_small.yaml +36 -0
- fairseq-0.10.2/config/model/transformer_lm_wiki103.yaml +36 -0
- fairseq-0.10.2/config/optimizer/adam.yaml +5 -0
- fairseq-0.10.2/config/optimizer/nag.yaml +3 -0
- fairseq-0.10.2/config/params/eval_lm_params.yaml +105 -0
- fairseq-0.10.2/config/params/training_params.yaml +95 -0
- fairseq-0.10.2/config/task/language_modeling.yaml +10 -0
- fairseq-0.10.2/examples/noisychannel/README.md +72 -0
- fairseq-0.10.2/examples/noisychannel/__init__.py +6 -0
- fairseq-0.10.2/examples/noisychannel/rerank.py +422 -0
- fairseq-0.10.2/examples/noisychannel/rerank_generate.py +397 -0
- fairseq-0.10.2/examples/noisychannel/rerank_options.py +149 -0
- fairseq-0.10.2/examples/noisychannel/rerank_score_bw.py +143 -0
- fairseq-0.10.2/examples/noisychannel/rerank_tune.py +102 -0
- fairseq-0.10.2/examples/noisychannel/rerank_utils.py +850 -0
- fairseq-0.10.2/examples/paraphraser/README.md +46 -0
- fairseq-0.10.2/examples/paraphraser/paraphrase.py +85 -0
- fairseq-0.10.2/examples/simultaneous_translation/README.md +106 -0
- fairseq-0.10.2/examples/simultaneous_translation/__init__.py +6 -0
- fairseq-0.10.2/examples/simultaneous_translation/docs/baseline.md +178 -0
- fairseq-0.10.2/examples/simultaneous_translation/docs/evaluation.md +115 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/__init__.py +4 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/agents/word_splitter.py +91 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/client.py +100 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/eval_latency.py +78 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/evaluate.py +81 -0
- fairseq-0.10.2/examples/simultaneous_translation/eval/server.py +89 -0
- fairseq-0.10.2/examples/simultaneous_translation/models/__init__.py +15 -0
- fairseq-0.10.2/examples/simultaneous_translation/models/transformer_monotonic_attention.py +322 -0
- fairseq-0.10.2/examples/simultaneous_translation/modules/__init__.py +24 -0
- fairseq-0.10.2/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +622 -0
fairseq-0.10.2/.github/ISSUE_TEMPLATE.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
|
| 2 |
+
|
| 3 |
+
Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
|
fairseq-0.10.2/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: 🐛 Bug Report
|
| 3 |
+
about: Submit a bug report to help us improve
|
| 4 |
+
labels: 'bug, needs triage'
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 🐛 Bug
|
| 8 |
+
|
| 9 |
+
<!-- A clear and concise description of what the bug is. -->
|
| 10 |
+
|
| 11 |
+
### To Reproduce
|
| 12 |
+
|
| 13 |
+
Steps to reproduce the behavior (**always include the command you ran**):
|
| 14 |
+
|
| 15 |
+
1. Run cmd '....'
|
| 16 |
+
2. See error
|
| 17 |
+
|
| 18 |
+
<!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
#### Code sample
|
| 22 |
+
<!-- Ideally attach a minimal code sample to reproduce the decried issue.
|
| 23 |
+
Minimal means having the shortest code but still preserving the bug. -->
|
| 24 |
+
|
| 25 |
+
### Expected behavior
|
| 26 |
+
|
| 27 |
+
<!-- A clear and concise description of what you expected to happen. -->
|
| 28 |
+
|
| 29 |
+
### Environment
|
| 30 |
+
|
| 31 |
+
- fairseq Version (e.g., 1.0 or master):
|
| 32 |
+
- PyTorch Version (e.g., 1.0)
|
| 33 |
+
- OS (e.g., Linux):
|
| 34 |
+
- How you installed fairseq (`pip`, source):
|
| 35 |
+
- Build command you used (if compiling from source):
|
| 36 |
+
- Python version:
|
| 37 |
+
- CUDA/cuDNN version:
|
| 38 |
+
- GPU models and configuration:
|
| 39 |
+
- Any other relevant information:
|
| 40 |
+
|
| 41 |
+
### Additional context
|
| 42 |
+
|
| 43 |
+
<!-- Add any other context about the problem here. -->
|
fairseq-0.10.2/.github/ISSUE_TEMPLATE/documentation.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: 📚 Documentation/Typos
|
| 3 |
+
about: Report an issue related to documentation or a typo
|
| 4 |
+
labels: 'documentation, needs triage'
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 📚 Documentation
|
| 8 |
+
|
| 9 |
+
For typos and doc fixes, please go ahead and:
|
| 10 |
+
|
| 11 |
+
1. Create an issue.
|
| 12 |
+
2. Fix the typo.
|
| 13 |
+
3. Submit a PR.
|
| 14 |
+
|
| 15 |
+
Thanks!
|
fairseq-0.10.2/.github/ISSUE_TEMPLATE/feature_request.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: 🚀 Feature Request
|
| 3 |
+
about: Submit a proposal/request for a new feature
|
| 4 |
+
labels: 'enhancement, help wanted, needs triage'
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## 🚀 Feature Request
|
| 8 |
+
<!-- A clear and concise description of the feature proposal -->
|
| 9 |
+
|
| 10 |
+
### Motivation
|
| 11 |
+
|
| 12 |
+
<!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
|
| 13 |
+
|
| 14 |
+
### Pitch
|
| 15 |
+
|
| 16 |
+
<!-- A clear and concise description of what you want to happen. -->
|
| 17 |
+
|
| 18 |
+
### Alternatives
|
| 19 |
+
|
| 20 |
+
<!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
|
| 21 |
+
|
| 22 |
+
### Additional context
|
| 23 |
+
|
| 24 |
+
<!-- Add any other context or screenshots about the feature request here. -->
|
fairseq-0.10.2/.github/ISSUE_TEMPLATE/how-to-question.md
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: ❓ Questions/Help
|
| 3 |
+
about: If you have questions, please first search existing issues and docs
|
| 4 |
+
labels: 'question, needs triage'
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## ❓ Questions and Help
|
| 8 |
+
|
| 9 |
+
### Before asking:
|
| 10 |
+
1. search the issues.
|
| 11 |
+
2. search the docs.
|
| 12 |
+
|
| 13 |
+
<!-- If you still can't find what you need: -->
|
| 14 |
+
|
| 15 |
+
#### What is your question?
|
| 16 |
+
|
| 17 |
+
#### Code
|
| 18 |
+
|
| 19 |
+
<!-- Please paste a code snippet if your question requires it! -->
|
| 20 |
+
|
| 21 |
+
#### What have you tried?
|
| 22 |
+
|
| 23 |
+
#### What's your environment?
|
| 24 |
+
|
| 25 |
+
- fairseq Version (e.g., 1.0 or master):
|
| 26 |
+
- PyTorch Version (e.g., 1.0)
|
| 27 |
+
- OS (e.g., Linux):
|
| 28 |
+
- How you installed fairseq (`pip`, source):
|
| 29 |
+
- Build command you used (if compiling from source):
|
| 30 |
+
- Python version:
|
| 31 |
+
- CUDA/cuDNN version:
|
| 32 |
+
- GPU models and configuration:
|
| 33 |
+
- Any other relevant information:
|
fairseq-0.10.2/config/config.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- params: training_params
|
| 3 |
+
- task: language_modeling
|
| 4 |
+
- model: transformer_lm
|
| 5 |
+
- criterion: cross_entropy
|
| 6 |
+
- optimizer: adam
|
| 7 |
+
- lr_scheduler: inverse_sqrt
|
fairseq-0.10.2/config/config_eval_lm.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- params: eval_lm_params
|
| 3 |
+
- task: language_modeling
|
| 4 |
+
- model: transformer_lm
|
| 5 |
+
- criterion: cross_entropy
|
| 6 |
+
- optimizer: adam
|
| 7 |
+
- lr_scheduler: inverse_sqrt
|
fairseq-0.10.2/config/criterion/adaptive_loss.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
sentence_avg: ${params.optimization.sentence_avg}
|
| 3 |
+
ddp_backend: ${params.distributed_training.ddp_backend}
|
fairseq-0.10.2/config/criterion/cross_entropy.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
sentence_avg: ${params.optimization.sentence_avg}
|
| 3 |
+
ddp_backend: ${params.distributed_training.ddp_backend}
|
fairseq-0.10.2/config/lr_scheduler/cosine.yaml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
warmup_updates: 0
|
| 3 |
+
warmup_init_lr: -1
|
| 4 |
+
max_lr: 1.0
|
| 5 |
+
t_mult: 1.0
|
| 6 |
+
lr_period_updates: -1
|
| 7 |
+
lr_shrink: 0.1
|
fairseq-0.10.2/config/lr_scheduler/inverse_sqrt.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
warmup_updates: 4000
|
| 3 |
+
warmup_init_lr: -1
|
fairseq-0.10.2/config/model/transformer_lm.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "relu"
|
| 3 |
+
dropout: 0.1
|
| 4 |
+
attention_dropout: 0.0
|
| 5 |
+
activation_dropout: 0.0
|
| 6 |
+
relu_dropout: 0.0
|
| 7 |
+
decoder_embed_dim: 512
|
| 8 |
+
decoder_output_dim: 512
|
| 9 |
+
decoder_input_dim: 512
|
| 10 |
+
decoder_ffn_embed_dim: 2048
|
| 11 |
+
decoder_layers: 6
|
| 12 |
+
decoder_attention_heads: 8
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: false
|
| 15 |
+
adaptive_softmax_cutoff: null
|
| 16 |
+
adaptive_softmax_dropout: 0
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: false
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: null
|
| 27 |
+
tie_adaptive_weights: false
|
| 28 |
+
tie_adaptive_proj: false
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_baevski_gbw.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "relu"
|
| 3 |
+
dropout: 0.1
|
| 4 |
+
attention_dropout: 0.1
|
| 5 |
+
activation_dropout: 0.0
|
| 6 |
+
relu_dropout: 0.0
|
| 7 |
+
decoder_embed_dim: 512
|
| 8 |
+
decoder_output_dim: 512
|
| 9 |
+
decoder_input_dim: 512
|
| 10 |
+
decoder_ffn_embed_dim: 4096
|
| 11 |
+
decoder_layers: 12
|
| 12 |
+
decoder_attention_heads: 16
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: true
|
| 15 |
+
adaptive_softmax_cutoff: null
|
| 16 |
+
adaptive_softmax_dropout: 0
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: false
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: null
|
| 27 |
+
tie_adaptive_weights: false
|
| 28 |
+
tie_adaptive_proj: false
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_baevski_wiki103.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "relu"
|
| 3 |
+
dropout: 0.3
|
| 4 |
+
attention_dropout: 0.1
|
| 5 |
+
activation_dropout: 0.1
|
| 6 |
+
relu_dropout: 0.1
|
| 7 |
+
decoder_embed_dim: 1024
|
| 8 |
+
decoder_output_dim: 1024
|
| 9 |
+
decoder_input_dim: 1024
|
| 10 |
+
decoder_ffn_embed_dim: 4096
|
| 11 |
+
decoder_layers: 16
|
| 12 |
+
decoder_attention_heads: 8
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: true
|
| 15 |
+
adaptive_softmax_cutoff: "20000,60000"
|
| 16 |
+
adaptive_softmax_dropout: 0.2
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: true
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: "20000,60000"
|
| 27 |
+
tie_adaptive_weights: true
|
| 28 |
+
tie_adaptive_proj: true
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_big.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "relu"
|
| 3 |
+
dropout: 0.1
|
| 4 |
+
attention_dropout: 0.0
|
| 5 |
+
activation_dropout: 0.0
|
| 6 |
+
relu_dropout: 0.0
|
| 7 |
+
decoder_embed_dim: 1024
|
| 8 |
+
decoder_output_dim: 1024
|
| 9 |
+
decoder_input_dim: 1024
|
| 10 |
+
decoder_ffn_embed_dim: 4096
|
| 11 |
+
decoder_layers: 12
|
| 12 |
+
decoder_attention_heads: 16
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: false
|
| 15 |
+
adaptive_softmax_cutoff: null
|
| 16 |
+
adaptive_softmax_dropout: 0
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: false
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: null
|
| 27 |
+
tie_adaptive_weights: false
|
| 28 |
+
tie_adaptive_proj: false
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_gbw.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "relu"
|
| 3 |
+
dropout: 0.1
|
| 4 |
+
attention_dropout: 0.1
|
| 5 |
+
activation_dropout: 0.0
|
| 6 |
+
relu_dropout: 0.0
|
| 7 |
+
decoder_embed_dim: 512
|
| 8 |
+
decoder_output_dim: 512
|
| 9 |
+
decoder_input_dim: 512
|
| 10 |
+
decoder_ffn_embed_dim: 4096
|
| 11 |
+
decoder_layers: 12
|
| 12 |
+
decoder_attention_heads: 16
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: true
|
| 15 |
+
adaptive_softmax_cutoff: null
|
| 16 |
+
adaptive_softmax_dropout: 0
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: false
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: null
|
| 27 |
+
tie_adaptive_weights: false
|
| 28 |
+
tie_adaptive_proj: false
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_gpt.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "gelu"
|
| 3 |
+
dropout: 0.1
|
| 4 |
+
attention_dropout: 0.1
|
| 5 |
+
activation_dropout: 0.0
|
| 6 |
+
relu_dropout: 0.0
|
| 7 |
+
decoder_embed_dim: 768
|
| 8 |
+
decoder_output_dim: 768
|
| 9 |
+
decoder_input_dim: 768
|
| 10 |
+
decoder_ffn_embed_dim: 3072
|
| 11 |
+
decoder_layers: 12
|
| 12 |
+
decoder_attention_heads: 12
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: false
|
| 15 |
+
adaptive_softmax_cutoff: null
|
| 16 |
+
adaptive_softmax_dropout: 0
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: false
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: null
|
| 27 |
+
tie_adaptive_weights: false
|
| 28 |
+
tie_adaptive_proj: false
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_gpt2_big.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "gelu"
|
| 3 |
+
dropout: 0.1
|
| 4 |
+
attention_dropout: 0.1
|
| 5 |
+
activation_dropout: 0.0
|
| 6 |
+
relu_dropout: 0.0
|
| 7 |
+
decoder_embed_dim: 1600
|
| 8 |
+
decoder_output_dim: 1600
|
| 9 |
+
decoder_input_dim: 1600
|
| 10 |
+
decoder_ffn_embed_dim: 6400
|
| 11 |
+
decoder_layers: 48
|
| 12 |
+
decoder_attention_heads: 25
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: false
|
| 15 |
+
adaptive_softmax_cutoff: null
|
| 16 |
+
adaptive_softmax_dropout: 0
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: false
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: null
|
| 27 |
+
tie_adaptive_weights: false
|
| 28 |
+
tie_adaptive_proj: false
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_gpt2_medium.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "gelu"
|
| 3 |
+
dropout: 0.1
|
| 4 |
+
attention_dropout: 0.1
|
| 5 |
+
activation_dropout: 0.0
|
| 6 |
+
relu_dropout: 0.0
|
| 7 |
+
decoder_embed_dim: 1280
|
| 8 |
+
decoder_output_dim: 1280
|
| 9 |
+
decoder_input_dim: 1280
|
| 10 |
+
decoder_ffn_embed_dim: 5120
|
| 11 |
+
decoder_layers: 36
|
| 12 |
+
decoder_attention_heads: 20
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: false
|
| 15 |
+
adaptive_softmax_cutoff: null
|
| 16 |
+
adaptive_softmax_dropout: 0
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: false
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: null
|
| 27 |
+
tie_adaptive_weights: false
|
| 28 |
+
tie_adaptive_proj: false
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_gpt2_small.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "gelu"
|
| 3 |
+
dropout: 0.1
|
| 4 |
+
attention_dropout: 0.1
|
| 5 |
+
activation_dropout: 0.0
|
| 6 |
+
relu_dropout: 0.0
|
| 7 |
+
decoder_embed_dim: 1024
|
| 8 |
+
decoder_output_dim: 1024
|
| 9 |
+
decoder_input_dim: 1024
|
| 10 |
+
decoder_ffn_embed_dim: 4096
|
| 11 |
+
decoder_layers: 24
|
| 12 |
+
decoder_attention_heads: 16
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: false
|
| 15 |
+
adaptive_softmax_cutoff: null
|
| 16 |
+
adaptive_softmax_dropout: 0
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: false
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: null
|
| 27 |
+
tie_adaptive_weights: false
|
| 28 |
+
tie_adaptive_proj: false
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/model/transformer_lm_wiki103.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
activation_fn: "relu"
|
| 3 |
+
dropout: 0.3
|
| 4 |
+
attention_dropout: 0.1
|
| 5 |
+
activation_dropout: 0.1
|
| 6 |
+
relu_dropout: 0.1
|
| 7 |
+
decoder_embed_dim: 1024
|
| 8 |
+
decoder_output_dim: 1024
|
| 9 |
+
decoder_input_dim: 1024
|
| 10 |
+
decoder_ffn_embed_dim: 4096
|
| 11 |
+
decoder_layers: 16
|
| 12 |
+
decoder_attention_heads: 8
|
| 13 |
+
decoder_normalize_before: true
|
| 14 |
+
no_decoder_final_norm: true
|
| 15 |
+
adaptive_softmax_cutoff: "20000,60000"
|
| 16 |
+
adaptive_softmax_dropout: 0.2
|
| 17 |
+
adaptive_softmax_factor: 4
|
| 18 |
+
no_token_positional_embeddings: false
|
| 19 |
+
share_decoder_input_output_embed: false
|
| 20 |
+
character_embeddings: false
|
| 21 |
+
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
|
| 22 |
+
character_embedding_dim: 4
|
| 23 |
+
char_embedder_highway_layers: 2
|
| 24 |
+
adaptive_input: true
|
| 25 |
+
adaptive_input_factor: 4
|
| 26 |
+
adaptive_input_cutoff: "20000,60000"
|
| 27 |
+
tie_adaptive_weights: true
|
| 28 |
+
tie_adaptive_proj: true
|
| 29 |
+
decoder_learned_pos: false
|
| 30 |
+
decoder_layerdrop: 0
|
| 31 |
+
decoder_layers_to_keep: null
|
| 32 |
+
layernorm_embedding: false
|
| 33 |
+
no_scale_embedding: false
|
| 34 |
+
quant_noise_pq: 0
|
| 35 |
+
quant_noise_pq_block_size: 8
|
| 36 |
+
quant_noise_scalar: 0
|
fairseq-0.10.2/config/optimizer/adam.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
adam_betas: "(0.9, 0.999)"
|
| 3 |
+
adam_eps: 1.0e-8
|
| 4 |
+
weight_decay: 0
|
| 5 |
+
use_old_adam: false
|
fairseq-0.10.2/config/optimizer/nag.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
momentum: 0.99
|
| 3 |
+
weight_decay: 0.0
|
fairseq-0.10.2/config/params/eval_lm_params.yaml
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
common:
|
| 3 |
+
no_progress_bar: false
|
| 4 |
+
log_interval: 100
|
| 5 |
+
log_format: null
|
| 6 |
+
tensorboard_logdir: null
|
| 7 |
+
seed: 1
|
| 8 |
+
cpu: false
|
| 9 |
+
fp16: false
|
| 10 |
+
memory_efficient_fp16: false
|
| 11 |
+
fp16_no_flatten_grads: false
|
| 12 |
+
fp16_init_scale: 128
|
| 13 |
+
fp16_scale_window: null
|
| 14 |
+
fp16_scale_tolerance: 0.0
|
| 15 |
+
min_loss_scale: 1.0e-4
|
| 16 |
+
threshold_loss_scale: null
|
| 17 |
+
user_dir: null
|
| 18 |
+
empty_cache_freq: 0
|
| 19 |
+
all_gather_list_size: 16384
|
| 20 |
+
model_parallel_size: 1
|
| 21 |
+
checkpoint_suffix: ""
|
| 22 |
+
quantization_config_path: null
|
| 23 |
+
distributed_training:
|
| 24 |
+
distributed_rank: 0
|
| 25 |
+
distributed_backend: "nccl"
|
| 26 |
+
distributed_init_method: null
|
| 27 |
+
distributed_port: -1
|
| 28 |
+
device_id: 0
|
| 29 |
+
local_rank: 0
|
| 30 |
+
distributed_no_spawn: false
|
| 31 |
+
ddp_backend: "c10d"
|
| 32 |
+
bucket_cap_mb: 25
|
| 33 |
+
fix_batches_to_gpus: false
|
| 34 |
+
find_unused_parameters: false
|
| 35 |
+
fast_stat_sync: false
|
| 36 |
+
broadcast_buffers: false
|
| 37 |
+
distributed_wrapper: "DDP"
|
| 38 |
+
slowmo_momentum: null
|
| 39 |
+
slowmo_algorithm: "LocalSGD"
|
| 40 |
+
localsgd_frequency: 3
|
| 41 |
+
dataset:
|
| 42 |
+
num_workers: 1
|
| 43 |
+
skip_invalid_size_inputs_valid_test: false
|
| 44 |
+
max_tokens: null
|
| 45 |
+
batch_size: ${params.dataset.batch_size}
|
| 46 |
+
required_batch_size_multiple: 8
|
| 47 |
+
dataset_impl: null
|
| 48 |
+
data_buffer_size: 10
|
| 49 |
+
train_subset: "train"
|
| 50 |
+
valid_subset: "valid"
|
| 51 |
+
validate_interval: 1
|
| 52 |
+
fixed_validation_seed: null
|
| 53 |
+
disable_validation: false
|
| 54 |
+
curriculum: 0
|
| 55 |
+
gen_subset: "test"
|
| 56 |
+
num_shards: 1
|
| 57 |
+
shard_id: 0
|
| 58 |
+
max_tokens_valid: ${params.dataset.max_tokens}
|
| 59 |
+
batch_size_valid: ${params.dataset.batch_size}
|
| 60 |
+
optimization:
|
| 61 |
+
max_epoch: 0
|
| 62 |
+
max_update: 0
|
| 63 |
+
clip_norm: 25.0
|
| 64 |
+
sentence_avg: false
|
| 65 |
+
update_freq: [1]
|
| 66 |
+
lr: [0.25]
|
| 67 |
+
min_lr: -1.0
|
| 68 |
+
use_bmuf: false
|
| 69 |
+
checkpoint:
|
| 70 |
+
save_dir: "checkpoints"
|
| 71 |
+
restore_file: "checkpoint_last.pt"
|
| 72 |
+
reset_dataloader: false
|
| 73 |
+
reset_lr_scheduler: false
|
| 74 |
+
reset_meters: false
|
| 75 |
+
reset_optimizer: false
|
| 76 |
+
optimizer_overrides: "{}"
|
| 77 |
+
save_interval: 1
|
| 78 |
+
save_interval_updates: 0
|
| 79 |
+
keep_interval_updates: -1
|
| 80 |
+
keep_last_epochs: -1
|
| 81 |
+
keep_best_checkpoints: -1
|
| 82 |
+
no_save: false
|
| 83 |
+
no_epoch_checkpoints: false
|
| 84 |
+
no_last_checkpoints: false
|
| 85 |
+
no_save_optimizer_state: false
|
| 86 |
+
best_checkpoint_metric: "loss"
|
| 87 |
+
maximize_best_checkpoint_metric: false
|
| 88 |
+
patience: -1
|
| 89 |
+
common_eval:
|
| 90 |
+
path: null
|
| 91 |
+
remove_bpe: null
|
| 92 |
+
quiet: false
|
| 93 |
+
model_overrides: "{}"
|
| 94 |
+
results_path: null
|
| 95 |
+
eval_lm:
|
| 96 |
+
output_word_probs: false
|
| 97 |
+
output_word_stats: false
|
| 98 |
+
context_window: 0
|
| 99 |
+
bmuf:
|
| 100 |
+
block_lr: 1
|
| 101 |
+
block_momentum: 0.875
|
| 102 |
+
global_sync_iter: 50
|
| 103 |
+
warmup_iterations: 500
|
| 104 |
+
use_nbm: false
|
| 105 |
+
average_sync: false
|
fairseq-0.10.2/config/params/training_params.yaml
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
common:
|
| 3 |
+
no_progress_bar: false
|
| 4 |
+
log_interval: 100
|
| 5 |
+
log_format: null
|
| 6 |
+
tensorboard_logdir: null
|
| 7 |
+
seed: 1
|
| 8 |
+
cpu: false
|
| 9 |
+
fp16: false
|
| 10 |
+
memory_efficient_fp16: false
|
| 11 |
+
fp16_no_flatten_grads: false
|
| 12 |
+
fp16_init_scale: 128
|
| 13 |
+
fp16_scale_window: null
|
| 14 |
+
fp16_scale_tolerance: 0.0
|
| 15 |
+
min_loss_scale: 1.0e-4
|
| 16 |
+
threshold_loss_scale: null
|
| 17 |
+
user_dir: null
|
| 18 |
+
empty_cache_freq: 0
|
| 19 |
+
all_gather_list_size: 16384
|
| 20 |
+
model_parallel_size: 1
|
| 21 |
+
checkpoint_suffix: ""
|
| 22 |
+
quantization_config_path: null
|
| 23 |
+
distributed_training:
|
| 24 |
+
distributed_rank: 0
|
| 25 |
+
distributed_backend: "nccl"
|
| 26 |
+
distributed_init_method: null
|
| 27 |
+
distributed_port: -1
|
| 28 |
+
device_id: 0
|
| 29 |
+
local_rank: 0
|
| 30 |
+
distributed_no_spawn: false
|
| 31 |
+
ddp_backend: "c10d"
|
| 32 |
+
bucket_cap_mb: 25
|
| 33 |
+
fix_batches_to_gpus: false
|
| 34 |
+
find_unused_parameters: false
|
| 35 |
+
fast_stat_sync: false
|
| 36 |
+
broadcast_buffers: false
|
| 37 |
+
distributed_wrapper: "DDP"
|
| 38 |
+
slowmo_momentum: null
|
| 39 |
+
slowmo_algorithm: "LocalSGD"
|
| 40 |
+
localsgd_frequency: 3
|
| 41 |
+
dataset:
|
| 42 |
+
num_workers: 1
|
| 43 |
+
skip_invalid_size_inputs_valid_test: false
|
| 44 |
+
max_tokens: null
|
| 45 |
+
batch_size: ${params.dataset.batch_size}
|
| 46 |
+
required_batch_size_multiple: 8
|
| 47 |
+
dataset_impl: null
|
| 48 |
+
data_buffer_size: 10
|
| 49 |
+
train_subset: "train"
|
| 50 |
+
valid_subset: "valid"
|
| 51 |
+
validate_interval: 1
|
| 52 |
+
fixed_validation_seed: null
|
| 53 |
+
disable_validation: false
|
| 54 |
+
curriculum: 0
|
| 55 |
+
gen_subset: "test"
|
| 56 |
+
num_shards: 1
|
| 57 |
+
shard_id: 0
|
| 58 |
+
max_tokens_valid: ${params.dataset.max_tokens}
|
| 59 |
+
batch_size_valid: ${params.dataset.batch_size}
|
| 60 |
+
optimization:
|
| 61 |
+
max_epoch: 0
|
| 62 |
+
max_update: 0
|
| 63 |
+
clip_norm: 25.0
|
| 64 |
+
sentence_avg: false
|
| 65 |
+
update_freq: [1]
|
| 66 |
+
lr: [0.25]
|
| 67 |
+
min_lr: -1.0
|
| 68 |
+
use_bmuf: false
|
| 69 |
+
checkpoint:
|
| 70 |
+
save_dir: "checkpoints"
|
| 71 |
+
restore_file: "checkpoint_last.pt"
|
| 72 |
+
reset_dataloader: false
|
| 73 |
+
reset_lr_scheduler: false
|
| 74 |
+
reset_meters: false
|
| 75 |
+
reset_optimizer: false
|
| 76 |
+
optimizer_overrides: "{}"
|
| 77 |
+
save_interval: 1
|
| 78 |
+
save_interval_updates: 0
|
| 79 |
+
keep_interval_updates: -1
|
| 80 |
+
keep_last_epochs: -1
|
| 81 |
+
keep_best_checkpoints: -1
|
| 82 |
+
no_save: false
|
| 83 |
+
no_epoch_checkpoints: false
|
| 84 |
+
no_last_checkpoints: false
|
| 85 |
+
no_save_optimizer_state: false
|
| 86 |
+
best_checkpoint_metric: "loss"
|
| 87 |
+
maximize_best_checkpoint_metric: false
|
| 88 |
+
patience: -1
|
| 89 |
+
bmuf:
|
| 90 |
+
block_lr: 1
|
| 91 |
+
block_momentum: 0.875
|
| 92 |
+
global_sync_iter: 50
|
| 93 |
+
warmup_iterations: 500
|
| 94 |
+
use_nbm: false
|
| 95 |
+
average_sync: false
|
fairseq-0.10.2/config/task/language_modeling.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# @package _group_
|
| 2 |
+
data: ???
|
| 3 |
+
sample_break_mode: "none"
|
| 4 |
+
tokens_per_sample: 1024
|
| 5 |
+
output_dictionary_size: -1
|
| 6 |
+
self_target: false
|
| 7 |
+
future_target: false
|
| 8 |
+
past_target: false
|
| 9 |
+
add_bos_token: false
|
| 10 |
+
max_target_positions: null
|
fairseq-0.10.2/examples/noisychannel/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Simple and Effective Noisy Channel Modeling for Neural Machine Translation (Yee et al., 2019)
|
| 2 |
+
This page contains pointers to pre-trained models as well as instructions on how to run the reranking scripts.
|
| 3 |
+
|
| 4 |
+
## Citation:
|
| 5 |
+
```bibtex
|
| 6 |
+
@inproceedings{yee2019simple,
|
| 7 |
+
title = {Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
|
| 8 |
+
author = {Kyra Yee and Yann Dauphin and Michael Auli},
|
| 9 |
+
booktitle = {Conference on Empirical Methods in Natural Language Processing},
|
| 10 |
+
year = {2019},
|
| 11 |
+
}
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
## Pre-trained Models:
|
| 15 |
+
|
| 16 |
+
Model | Description | Download
|
| 17 |
+
---|---|---
|
| 18 |
+
`transformer.noisychannel.de-en` | De->En Forward Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2)
|
| 19 |
+
`transformer.noisychannel.en-de` | En->De Channel Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2)
|
| 20 |
+
`transformer_lm.noisychannel.en` | En Language model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2)
|
| 21 |
+
|
| 22 |
+
Test Data: [newstest_wmt17](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2)
|
| 23 |
+
|
| 24 |
+
## Example usage
|
| 25 |
+
|
| 26 |
+
```
|
| 27 |
+
mkdir rerank_example
|
| 28 |
+
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2 | tar xvjf - -C rerank_example
|
| 29 |
+
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2 | tar xvjf - -C rerank_example
|
| 30 |
+
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2 | tar xvjf - -C rerank_example
|
| 31 |
+
curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2 | tar xvjf - -C rerank_example
|
| 32 |
+
|
| 33 |
+
beam=50
|
| 34 |
+
num_trials=1000
|
| 35 |
+
fw_name=fw_model_ex
|
| 36 |
+
bw_name=bw_model_ex
|
| 37 |
+
lm_name=lm_ex
|
| 38 |
+
data_dir=rerank_example/hyphen-splitting-mixed-case-wmt17test-wmt14bpe
|
| 39 |
+
data_dir_name=wmt17
|
| 40 |
+
lm=rerank_example/lm/checkpoint_best.pt
|
| 41 |
+
lm_bpe_code=rerank_example/lm/bpe32k.code
|
| 42 |
+
lm_dict=rerank_example/lm/dict.txt
|
| 43 |
+
batch_size=32
|
| 44 |
+
bw=rerank_example/backward_en2de.pt
|
| 45 |
+
fw=rerank_example/forward_de2en.pt
|
| 46 |
+
|
| 47 |
+
# reranking with P(T|S) P(S|T) and P(T)
|
| 48 |
+
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight1 weight3 \
|
| 49 |
+
--lower-bound 0 0 0 --upper-bound 3 3 3 --data-dir-name $data_dir_name \
|
| 50 |
+
--num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
|
| 51 |
+
-n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw \
|
| 52 |
+
--backwards1 --weight2 1 \
|
| 53 |
+
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
|
| 54 |
+
--model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
|
| 55 |
+
|
| 56 |
+
# reranking with P(T|S) and P(T)
|
| 57 |
+
python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight3 \
|
| 58 |
+
--lower-bound 0 0 --upper-bound 3 3 --data-dir-name $data_dir_name \
|
| 59 |
+
--num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
|
| 60 |
+
-n $beam --batch-size $batch_size --score-model1 $fw \
|
| 61 |
+
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
|
| 62 |
+
--model1-name $fw_name --gen-model-name $fw_name
|
| 63 |
+
|
| 64 |
+
# to run with a preconfigured set of hyperparameters for the lenpen and model weights, using rerank.py instead.
|
| 65 |
+
python examples/noisychannel/rerank.py $data_dir \
|
| 66 |
+
--lenpen 0.269 --weight1 1 --weight2 0.929 --weight3 0.831 \
|
| 67 |
+
--data-dir-name $data_dir_name --source-lang de --target-lang en --gen-model $fw \
|
| 68 |
+
-n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw --backwards1 \
|
| 69 |
+
-lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
|
| 70 |
+
--model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
|
| 71 |
+
```
|
| 72 |
+
|
fairseq-0.10.2/examples/noisychannel/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .rerank_options import * # noqa
|
fairseq-0.10.2/examples/noisychannel/rerank.py
ADDED
|
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
from multiprocessing import Pool
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from fairseq import options
|
| 11 |
+
from fairseq.data import dictionary
|
| 12 |
+
from fairseq.scoring import bleu
|
| 13 |
+
|
| 14 |
+
from . import (
|
| 15 |
+
rerank_generate,
|
| 16 |
+
rerank_options,
|
| 17 |
+
rerank_score_bw,
|
| 18 |
+
rerank_score_lm,
|
| 19 |
+
rerank_utils,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def score_target_hypo(
|
| 24 |
+
args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
|
| 25 |
+
):
|
| 26 |
+
|
| 27 |
+
print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
|
| 28 |
+
gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
|
| 29 |
+
dict = dictionary.Dictionary()
|
| 30 |
+
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
|
| 31 |
+
|
| 32 |
+
ordered_hypos = {}
|
| 33 |
+
ordered_targets = {}
|
| 34 |
+
|
| 35 |
+
for shard_id in range(len(bitext1_lst)):
|
| 36 |
+
bitext1 = bitext1_lst[shard_id]
|
| 37 |
+
bitext2 = bitext2_lst[shard_id]
|
| 38 |
+
gen_output = gen_output_lst[shard_id]
|
| 39 |
+
lm_res = lm_res_lst[shard_id]
|
| 40 |
+
|
| 41 |
+
total = len(bitext1.rescore_source.keys())
|
| 42 |
+
source_lst = []
|
| 43 |
+
hypo_lst = []
|
| 44 |
+
score_lst = []
|
| 45 |
+
reference_lst = []
|
| 46 |
+
j = 1
|
| 47 |
+
best_score = -math.inf
|
| 48 |
+
|
| 49 |
+
for i in range(total):
|
| 50 |
+
# length is measured in terms of words, not bpe tokens, since models may not share the same bpe
|
| 51 |
+
target_len = len(bitext1.rescore_hypo[i].split())
|
| 52 |
+
|
| 53 |
+
if lm_res is not None:
|
| 54 |
+
lm_score = lm_res.score[i]
|
| 55 |
+
else:
|
| 56 |
+
lm_score = 0
|
| 57 |
+
|
| 58 |
+
if bitext2 is not None:
|
| 59 |
+
bitext2_score = bitext2.rescore_score[i]
|
| 60 |
+
bitext2_backwards = bitext2.backwards
|
| 61 |
+
else:
|
| 62 |
+
bitext2_score = None
|
| 63 |
+
bitext2_backwards = None
|
| 64 |
+
|
| 65 |
+
score = rerank_utils.get_score(
|
| 66 |
+
a,
|
| 67 |
+
b,
|
| 68 |
+
c,
|
| 69 |
+
target_len,
|
| 70 |
+
bitext1.rescore_score[i],
|
| 71 |
+
bitext2_score,
|
| 72 |
+
lm_score=lm_score,
|
| 73 |
+
lenpen=lenpen,
|
| 74 |
+
src_len=bitext1.source_lengths[i],
|
| 75 |
+
tgt_len=bitext1.target_lengths[i],
|
| 76 |
+
bitext1_backwards=bitext1.backwards,
|
| 77 |
+
bitext2_backwards=bitext2_backwards,
|
| 78 |
+
normalize=normalize,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
if score > best_score:
|
| 82 |
+
best_score = score
|
| 83 |
+
best_hypo = bitext1.rescore_hypo[i]
|
| 84 |
+
|
| 85 |
+
if j == gen_output.num_hypos[i] or j == args.num_rescore:
|
| 86 |
+
j = 1
|
| 87 |
+
hypo_lst.append(best_hypo)
|
| 88 |
+
score_lst.append(best_score)
|
| 89 |
+
source_lst.append(bitext1.rescore_source[i])
|
| 90 |
+
reference_lst.append(bitext1.rescore_target[i])
|
| 91 |
+
|
| 92 |
+
best_score = -math.inf
|
| 93 |
+
best_hypo = ""
|
| 94 |
+
else:
|
| 95 |
+
j += 1
|
| 96 |
+
|
| 97 |
+
gen_keys = list(sorted(gen_output.no_bpe_target.keys()))
|
| 98 |
+
|
| 99 |
+
for key in range(len(gen_keys)):
|
| 100 |
+
if args.prefix_len is None:
|
| 101 |
+
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
|
| 102 |
+
"pred and rescore hypo mismatch: i: "
|
| 103 |
+
+ str(key)
|
| 104 |
+
+ ", "
|
| 105 |
+
+ str(hypo_lst[key])
|
| 106 |
+
+ str(gen_keys[key])
|
| 107 |
+
+ str(gen_output.no_bpe_hypo[key])
|
| 108 |
+
)
|
| 109 |
+
sys_tok = dict.encode_line(hypo_lst[key])
|
| 110 |
+
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
|
| 111 |
+
scorer.add(ref_tok, sys_tok)
|
| 112 |
+
|
| 113 |
+
else:
|
| 114 |
+
full_hypo = rerank_utils.get_full_from_prefix(
|
| 115 |
+
hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
|
| 116 |
+
)
|
| 117 |
+
sys_tok = dict.encode_line(full_hypo)
|
| 118 |
+
ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
|
| 119 |
+
scorer.add(ref_tok, sys_tok)
|
| 120 |
+
|
| 121 |
+
# if only one set of hyper parameters is provided, write the predictions to a file
|
| 122 |
+
if write_hypos:
|
| 123 |
+
# recover the orinal ids from n best list generation
|
| 124 |
+
for key in range(len(gen_output.no_bpe_target)):
|
| 125 |
+
if args.prefix_len is None:
|
| 126 |
+
assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
|
| 127 |
+
"pred and rescore hypo mismatch:"
|
| 128 |
+
+ "i:"
|
| 129 |
+
+ str(key)
|
| 130 |
+
+ str(hypo_lst[key])
|
| 131 |
+
+ str(gen_output.no_bpe_hypo[key])
|
| 132 |
+
)
|
| 133 |
+
ordered_hypos[gen_keys[key]] = hypo_lst[key]
|
| 134 |
+
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
|
| 135 |
+
gen_keys[key]
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
else:
|
| 139 |
+
full_hypo = rerank_utils.get_full_from_prefix(
|
| 140 |
+
hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
|
| 141 |
+
)
|
| 142 |
+
ordered_hypos[gen_keys[key]] = full_hypo
|
| 143 |
+
ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
|
| 144 |
+
gen_keys[key]
|
| 145 |
+
]
|
| 146 |
+
|
| 147 |
+
# write the hypos in the original order from nbest list generation
|
| 148 |
+
if args.num_shards == (len(bitext1_lst)):
|
| 149 |
+
with open(target_outfile, "w") as t:
|
| 150 |
+
with open(hypo_outfile, "w") as h:
|
| 151 |
+
for key in range(len(ordered_hypos)):
|
| 152 |
+
t.write(ordered_targets[key])
|
| 153 |
+
h.write(ordered_hypos[key])
|
| 154 |
+
|
| 155 |
+
res = scorer.result_string(4)
|
| 156 |
+
if write_hypos:
|
| 157 |
+
print(res)
|
| 158 |
+
score = rerank_utils.parse_bleu_scoring(res)
|
| 159 |
+
return score
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def match_target_hypo(args, target_outfile, hypo_outfile):
|
| 163 |
+
"""combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
|
| 164 |
+
if len(args.weight1) == 1:
|
| 165 |
+
res = score_target_hypo(
|
| 166 |
+
args,
|
| 167 |
+
args.weight1[0],
|
| 168 |
+
args.weight2[0],
|
| 169 |
+
args.weight3[0],
|
| 170 |
+
args.lenpen[0],
|
| 171 |
+
target_outfile,
|
| 172 |
+
hypo_outfile,
|
| 173 |
+
True,
|
| 174 |
+
args.normalize,
|
| 175 |
+
)
|
| 176 |
+
rerank_scores = [res]
|
| 177 |
+
else:
|
| 178 |
+
print("launching pool")
|
| 179 |
+
with Pool(32) as p:
|
| 180 |
+
rerank_scores = p.starmap(
|
| 181 |
+
score_target_hypo,
|
| 182 |
+
[
|
| 183 |
+
(
|
| 184 |
+
args,
|
| 185 |
+
args.weight1[i],
|
| 186 |
+
args.weight2[i],
|
| 187 |
+
args.weight3[i],
|
| 188 |
+
args.lenpen[i],
|
| 189 |
+
target_outfile,
|
| 190 |
+
hypo_outfile,
|
| 191 |
+
False,
|
| 192 |
+
args.normalize,
|
| 193 |
+
)
|
| 194 |
+
for i in range(len(args.weight1))
|
| 195 |
+
],
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
if len(rerank_scores) > 1:
|
| 199 |
+
best_index = np.argmax(rerank_scores)
|
| 200 |
+
best_score = rerank_scores[best_index]
|
| 201 |
+
print("best score", best_score)
|
| 202 |
+
print("best lenpen", args.lenpen[best_index])
|
| 203 |
+
print("best weight1", args.weight1[best_index])
|
| 204 |
+
print("best weight2", args.weight2[best_index])
|
| 205 |
+
print("best weight3", args.weight3[best_index])
|
| 206 |
+
return (
|
| 207 |
+
args.lenpen[best_index],
|
| 208 |
+
args.weight1[best_index],
|
| 209 |
+
args.weight2[best_index],
|
| 210 |
+
args.weight3[best_index],
|
| 211 |
+
best_score,
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
else:
|
| 215 |
+
return (
|
| 216 |
+
args.lenpen[0],
|
| 217 |
+
args.weight1[0],
|
| 218 |
+
args.weight2[0],
|
| 219 |
+
args.weight3[0],
|
| 220 |
+
rerank_scores[0],
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def load_score_files(args):
|
| 225 |
+
if args.all_shards:
|
| 226 |
+
shard_ids = list(range(args.num_shards))
|
| 227 |
+
else:
|
| 228 |
+
shard_ids = [args.shard_id]
|
| 229 |
+
|
| 230 |
+
gen_output_lst = []
|
| 231 |
+
bitext1_lst = []
|
| 232 |
+
bitext2_lst = []
|
| 233 |
+
lm_res1_lst = []
|
| 234 |
+
|
| 235 |
+
for shard_id in shard_ids:
|
| 236 |
+
using_nbest = args.nbest_list is not None
|
| 237 |
+
(
|
| 238 |
+
pre_gen,
|
| 239 |
+
left_to_right_preprocessed_dir,
|
| 240 |
+
right_to_left_preprocessed_dir,
|
| 241 |
+
backwards_preprocessed_dir,
|
| 242 |
+
lm_preprocessed_dir,
|
| 243 |
+
) = rerank_utils.get_directories(
|
| 244 |
+
args.data_dir_name,
|
| 245 |
+
args.num_rescore,
|
| 246 |
+
args.gen_subset,
|
| 247 |
+
args.gen_model_name,
|
| 248 |
+
shard_id,
|
| 249 |
+
args.num_shards,
|
| 250 |
+
args.sampling,
|
| 251 |
+
args.prefix_len,
|
| 252 |
+
args.target_prefix_frac,
|
| 253 |
+
args.source_prefix_frac,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
rerank1_is_gen = (
|
| 257 |
+
args.gen_model == args.score_model1 and args.source_prefix_frac is None
|
| 258 |
+
)
|
| 259 |
+
rerank2_is_gen = (
|
| 260 |
+
args.gen_model == args.score_model2 and args.source_prefix_frac is None
|
| 261 |
+
)
|
| 262 |
+
|
| 263 |
+
score1_file = rerank_utils.rescore_file_name(
|
| 264 |
+
pre_gen,
|
| 265 |
+
args.prefix_len,
|
| 266 |
+
args.model1_name,
|
| 267 |
+
target_prefix_frac=args.target_prefix_frac,
|
| 268 |
+
source_prefix_frac=args.source_prefix_frac,
|
| 269 |
+
backwards=args.backwards1,
|
| 270 |
+
)
|
| 271 |
+
if args.score_model2 is not None:
|
| 272 |
+
score2_file = rerank_utils.rescore_file_name(
|
| 273 |
+
pre_gen,
|
| 274 |
+
args.prefix_len,
|
| 275 |
+
args.model2_name,
|
| 276 |
+
target_prefix_frac=args.target_prefix_frac,
|
| 277 |
+
source_prefix_frac=args.source_prefix_frac,
|
| 278 |
+
backwards=args.backwards2,
|
| 279 |
+
)
|
| 280 |
+
if args.language_model is not None:
|
| 281 |
+
lm_score_file = rerank_utils.rescore_file_name(
|
| 282 |
+
pre_gen, args.prefix_len, args.lm_name, lm_file=True
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# get gen output
|
| 286 |
+
predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
|
| 287 |
+
if using_nbest:
|
| 288 |
+
print("Using predefined n-best list from interactive.py")
|
| 289 |
+
predictions_bpe_file = args.nbest_list
|
| 290 |
+
gen_output = rerank_utils.BitextOutputFromGen(
|
| 291 |
+
predictions_bpe_file,
|
| 292 |
+
bpe_symbol=args.remove_bpe,
|
| 293 |
+
nbest=using_nbest,
|
| 294 |
+
prefix_len=args.prefix_len,
|
| 295 |
+
target_prefix_frac=args.target_prefix_frac,
|
| 296 |
+
)
|
| 297 |
+
|
| 298 |
+
if rerank1_is_gen:
|
| 299 |
+
bitext1 = gen_output
|
| 300 |
+
else:
|
| 301 |
+
bitext1 = rerank_utils.BitextOutput(
|
| 302 |
+
score1_file,
|
| 303 |
+
args.backwards1,
|
| 304 |
+
args.right_to_left1,
|
| 305 |
+
args.remove_bpe,
|
| 306 |
+
args.prefix_len,
|
| 307 |
+
args.target_prefix_frac,
|
| 308 |
+
args.source_prefix_frac,
|
| 309 |
+
)
|
| 310 |
+
|
| 311 |
+
if args.score_model2 is not None or args.nbest_list is not None:
|
| 312 |
+
if rerank2_is_gen:
|
| 313 |
+
bitext2 = gen_output
|
| 314 |
+
else:
|
| 315 |
+
bitext2 = rerank_utils.BitextOutput(
|
| 316 |
+
score2_file,
|
| 317 |
+
args.backwards2,
|
| 318 |
+
args.right_to_left2,
|
| 319 |
+
args.remove_bpe,
|
| 320 |
+
args.prefix_len,
|
| 321 |
+
args.target_prefix_frac,
|
| 322 |
+
args.source_prefix_frac,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
assert (
|
| 326 |
+
bitext2.source_lengths == bitext1.source_lengths
|
| 327 |
+
), "source lengths for rescoring models do not match"
|
| 328 |
+
assert (
|
| 329 |
+
bitext2.target_lengths == bitext1.target_lengths
|
| 330 |
+
), "target lengths for rescoring models do not match"
|
| 331 |
+
else:
|
| 332 |
+
if args.diff_bpe:
|
| 333 |
+
assert args.score_model2 is None
|
| 334 |
+
bitext2 = gen_output
|
| 335 |
+
else:
|
| 336 |
+
bitext2 = None
|
| 337 |
+
|
| 338 |
+
if args.language_model is not None:
|
| 339 |
+
lm_res1 = rerank_utils.LMOutput(
|
| 340 |
+
lm_score_file,
|
| 341 |
+
args.lm_dict,
|
| 342 |
+
args.prefix_len,
|
| 343 |
+
args.remove_bpe,
|
| 344 |
+
args.target_prefix_frac,
|
| 345 |
+
)
|
| 346 |
+
else:
|
| 347 |
+
lm_res1 = None
|
| 348 |
+
|
| 349 |
+
gen_output_lst.append(gen_output)
|
| 350 |
+
bitext1_lst.append(bitext1)
|
| 351 |
+
bitext2_lst.append(bitext2)
|
| 352 |
+
lm_res1_lst.append(lm_res1)
|
| 353 |
+
return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def rerank(args):
|
| 357 |
+
if type(args.lenpen) is not list:
|
| 358 |
+
args.lenpen = [args.lenpen]
|
| 359 |
+
if type(args.weight1) is not list:
|
| 360 |
+
args.weight1 = [args.weight1]
|
| 361 |
+
if type(args.weight2) is not list:
|
| 362 |
+
args.weight2 = [args.weight2]
|
| 363 |
+
if type(args.weight3) is not list:
|
| 364 |
+
args.weight3 = [args.weight3]
|
| 365 |
+
if args.all_shards:
|
| 366 |
+
shard_ids = list(range(args.num_shards))
|
| 367 |
+
else:
|
| 368 |
+
shard_ids = [args.shard_id]
|
| 369 |
+
|
| 370 |
+
for shard_id in shard_ids:
|
| 371 |
+
(
|
| 372 |
+
pre_gen,
|
| 373 |
+
left_to_right_preprocessed_dir,
|
| 374 |
+
right_to_left_preprocessed_dir,
|
| 375 |
+
backwards_preprocessed_dir,
|
| 376 |
+
lm_preprocessed_dir,
|
| 377 |
+
) = rerank_utils.get_directories(
|
| 378 |
+
args.data_dir_name,
|
| 379 |
+
args.num_rescore,
|
| 380 |
+
args.gen_subset,
|
| 381 |
+
args.gen_model_name,
|
| 382 |
+
shard_id,
|
| 383 |
+
args.num_shards,
|
| 384 |
+
args.sampling,
|
| 385 |
+
args.prefix_len,
|
| 386 |
+
args.target_prefix_frac,
|
| 387 |
+
args.source_prefix_frac,
|
| 388 |
+
)
|
| 389 |
+
rerank_generate.gen_and_reprocess_nbest(args)
|
| 390 |
+
rerank_score_bw.score_bw(args)
|
| 391 |
+
rerank_score_lm.score_lm(args)
|
| 392 |
+
|
| 393 |
+
if args.write_hypos is None:
|
| 394 |
+
write_targets = pre_gen + "/matched_targets"
|
| 395 |
+
write_hypos = pre_gen + "/matched_hypos"
|
| 396 |
+
else:
|
| 397 |
+
write_targets = args.write_hypos + "_targets" + args.gen_subset
|
| 398 |
+
write_hypos = args.write_hypos + "_hypos" + args.gen_subset
|
| 399 |
+
|
| 400 |
+
if args.all_shards:
|
| 401 |
+
write_targets += "_all_shards"
|
| 402 |
+
write_hypos += "_all_shards"
|
| 403 |
+
|
| 404 |
+
(
|
| 405 |
+
best_lenpen,
|
| 406 |
+
best_weight1,
|
| 407 |
+
best_weight2,
|
| 408 |
+
best_weight3,
|
| 409 |
+
best_score,
|
| 410 |
+
) = match_target_hypo(args, write_targets, write_hypos)
|
| 411 |
+
|
| 412 |
+
return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def cli_main():
|
| 416 |
+
parser = rerank_options.get_reranking_parser()
|
| 417 |
+
args = options.parse_args_and_arch(parser)
|
| 418 |
+
rerank(args)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
if __name__ == "__main__":
|
| 422 |
+
cli_main()
|
fairseq-0.10.2/examples/noisychannel/rerank_generate.py
ADDED
|
@@ -0,0 +1,397 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3 -u
|
| 2 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the MIT license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""
|
| 8 |
+
Generate n-best translations using a trained model.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import os
|
| 12 |
+
import subprocess
|
| 13 |
+
from contextlib import redirect_stdout
|
| 14 |
+
|
| 15 |
+
from fairseq import options
|
| 16 |
+
from fairseq_cli import generate, preprocess
|
| 17 |
+
|
| 18 |
+
from . import rerank_options, rerank_utils
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def gen_and_reprocess_nbest(args):
    """Generate an n-best list with the forward p(T|S) model and preprocess it
    for rescoring.

    Steps (each skipped if its output already exists on disk):
      1. run ``fairseq_cli.generate`` to produce an n-best list with BPE,
         or reuse a predefined list from ``--nbest-list``;
      2. strip/rewrite the generated output into clean parallel text files;
      3. binarize those files with ``fairseq_cli.preprocess`` so the scoring
         models can rescore them.

    Returns the parsed generation output (``rerank_utils.BitextOutputFromGen``).
    """
    # Default the scoring dictionaries to the generation data directory.
    if args.score_dict_dir is None:
        args.score_dict_dir = args.data
    if args.prefix_len is not None:
        assert (
            args.right_to_left1 is False
        ), "prefix length not compatible with right to left models"
        assert (
            args.right_to_left2 is False
        ), "prefix length not compatible with right to left models"

    # A predefined n-best list (from interactive.py) only supports one scorer.
    if args.nbest_list is not None:
        assert args.score_model2 is None

    # A backwards model scores p(S|T), so its source/target are swapped.
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    store_data = (
        os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
    )
    if not os.path.exists(store_data):
        os.makedirs(store_data)

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )
    assert not (
        args.right_to_left1 and args.backwards1
    ), "backwards right to left not supported"
    assert not (
        args.right_to_left2 and args.backwards2
    ), "backwards right to left not supported"
    assert not (
        args.prefix_len is not None and args.target_prefix_frac is not None
    ), "target prefix frac and target prefix len incompatible"

    # make directory to store generation results
    if not os.path.exists(pre_gen):
        os.makedirs(pre_gen)

    # When a scoring model is the generation model itself (and no source
    # prefix truncation is used), its scores can be read from the generation
    # output directly, so rescoring is skipped for it.
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    if args.nbest_list is not None:
        rerank2_is_gen = True

    # make directories to store preprossed nbest list for reranking
    if not os.path.exists(left_to_right_preprocessed_dir):
        os.makedirs(left_to_right_preprocessed_dir)
    if not os.path.exists(right_to_left_preprocessed_dir):
        os.makedirs(right_to_left_preprocessed_dir)
    if not os.path.exists(lm_preprocessed_dir):
        os.makedirs(lm_preprocessed_dir)
    if not os.path.exists(backwards_preprocessed_dir):
        os.makedirs(backwards_preprocessed_dir)

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"

    using_nbest = args.nbest_list is not None

    if using_nbest:
        print("Using predefined n-best list from interactive.py")
        predictions_bpe_file = args.nbest_list

    else:
        if not os.path.isfile(predictions_bpe_file):
            print("STEP 1: generate predictions using the p(T|S) model with bpe")
            print(args.data)
            # CLI argument vector handed to fairseq's generation parser.
            # NOTE(review): "--batch-size" appears twice below; argparse keeps
            # the last value, so the effective batch size is args.num_rescore
            # and args.batch_size is silently ignored — confirm intended.
            param1 = [
                args.data,
                "--path",
                args.gen_model,
                "--shard-id",
                str(args.shard_id),
                "--num-shards",
                str(args.num_shards),
                "--nbest",
                str(args.num_rescore),
                "--batch-size",
                str(args.batch_size),
                "--beam",
                str(args.num_rescore),
                "--batch-size",
                str(args.num_rescore),
                "--gen-subset",
                args.gen_subset,
                "--source-lang",
                args.source_lang,
                "--target-lang",
                args.target_lang,
            ]
            if args.sampling:
                param1 += ["--sampling"]

            gen_parser = options.get_generation_parser()
            input_args = options.parse_args_and_arch(gen_parser, param1)

            print(input_args)
            # generate.main prints to stdout; capture it into the bpe file.
            with open(predictions_bpe_file, "w") as f:
                with redirect_stdout(f):
                    generate.main(input_args)

    # Parse the raw generate.py output into aligned source/hypo/target lists.
    gen_output = rerank_utils.BitextOutputFromGen(
        predictions_bpe_file,
        bpe_symbol=args.remove_bpe,
        nbest=using_nbest,
        prefix_len=args.prefix_len,
        target_prefix_frac=args.target_prefix_frac,
    )

    # If the rescoring models use a different BPE than the generation model,
    # strip the generation BPE and re-apply the rescoring BPE via subword-nmt.
    if args.diff_bpe:
        rerank_utils.write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/source_gen_bpe." + args.source_lang,
            pre_gen + "/target_gen_bpe." + args.target_lang,
            pre_gen + "/reference_gen_bpe." + args.target_lang,
        )
        bitext_bpe = args.rescore_bpe_code
        bpe_src_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/source_gen_bpe." + args.source_lang,
            "--output",
            pre_gen + "/rescore_data." + args.source_lang,
        ]
        bpe_tgt_param = [
            "-c",
            bitext_bpe,
            "--input",
            pre_gen + "/target_gen_bpe." + args.target_lang,
            "--output",
            pre_gen + "/rescore_data." + args.target_lang,
        ]

        # NOTE(review): assumes a subword-nmt checkout lives next to this
        # module and that "python" on PATH is a suitable interpreter.
        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_src_param,
            shell=False,
        )

        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_tgt_param,
            shell=False,
        )

    # Only reprocess/binarize if at least one scorer still needs its score file.
    if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print(
            "STEP 2: process the output of generate.py so we have clean text files with the translations"
        )

        # Base name for the rescoring text files; prefix variants encode the
        # truncation settings in the file name.
        rescore_file = "/rescore_data"
        if args.prefix_len is not None:
            prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
        if args.target_prefix_frac is not None:
            target_prefix_frac_rescore_file = (
                rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
            )
        if args.source_prefix_frac is not None:
            source_prefix_frac_rescore_file = (
                rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
            )

        if not args.right_to_left1 or not args.right_to_left2:
            if not args.diff_bpe:
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + rescore_file + "." + args.source_lang,
                    pre_gen + rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                )
            # Pick the input file for the backwards scorer, writing a
            # prefix-truncated variant when requested.
            if args.prefix_len is not None:
                bw_rescore_file = prefix_len_rescore_file
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen + prefix_len_rescore_file + "." + args.source_lang,
                    pre_gen + prefix_len_rescore_file + "." + args.target_lang,
                    pre_gen + "/reference_file",
                    prefix_len=args.prefix_len,
                    bpe_symbol=args.remove_bpe,
                )
            elif args.target_prefix_frac is not None:
                bw_rescore_file = target_prefix_frac_rescore_file
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen
                    + target_prefix_frac_rescore_file
                    + "."
                    + args.source_lang,
                    pre_gen
                    + target_prefix_frac_rescore_file
                    + "."
                    + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                    target_prefix_frac=args.target_prefix_frac,
                )
            else:
                bw_rescore_file = rescore_file

            # Pick the input file for the forward scorer.
            if args.source_prefix_frac is not None:
                fw_rescore_file = source_prefix_frac_rescore_file
                rerank_utils.write_reprocessed(
                    gen_output.source,
                    gen_output.hypo,
                    gen_output.target,
                    pre_gen
                    + source_prefix_frac_rescore_file
                    + "."
                    + args.source_lang,
                    pre_gen
                    + source_prefix_frac_rescore_file
                    + "."
                    + args.target_lang,
                    pre_gen + "/reference_file",
                    bpe_symbol=args.remove_bpe,
                    source_prefix_frac=args.source_prefix_frac,
                )
            else:
                fw_rescore_file = rescore_file

        # Right-to-left scorers need the token order reversed.
        if args.right_to_left1 or args.right_to_left2:
            rerank_utils.write_reprocessed(
                gen_output.source,
                gen_output.hypo,
                gen_output.target,
                pre_gen + "/right_to_left_rescore_data." + args.source_lang,
                pre_gen + "/right_to_left_rescore_data." + args.target_lang,
                pre_gen + "/right_to_left_reference_file",
                right_to_left=True,
                bpe_symbol=args.remove_bpe,
            )

        print("STEP 3: binarize the translations")
        # NOTE(review): this condition mixes `or`/`and` without parentheses;
        # it parses as A or (B and C) or D — confirm that grouping is intended.
        if (
            not args.right_to_left1
            or args.score_model2 is not None
            and not args.right_to_left2
            or not rerank1_is_gen
        ):

            if args.backwards1 or args.backwards2:
                # Backwards models may have their own dictionaries.
                if args.backwards_score_dict_dir is not None:
                    bw_dict = args.backwards_score_dict_dir
                else:
                    bw_dict = args.score_dict_dir
                bw_preprocess_param = [
                    "--source-lang",
                    scorer1_src,
                    "--target-lang",
                    scorer1_tgt,
                    "--trainpref",
                    pre_gen + bw_rescore_file,
                    "--srcdict",
                    bw_dict + "/dict." + scorer1_src + ".txt",
                    "--tgtdict",
                    bw_dict + "/dict." + scorer1_tgt + ".txt",
                    "--destdir",
                    backwards_preprocessed_dir,
                ]
                preprocess_parser = options.get_preprocessing_parser()
                input_args = preprocess_parser.parse_args(bw_preprocess_param)
                preprocess.main(input_args)

            # Binarize the forward (left-to-right) rescoring data.
            preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + fw_rescore_file,
                "--srcdict",
                args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                left_to_right_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

        # Binarize the reversed data for right-to-left scorers.
        if args.right_to_left1 or args.right_to_left2:
            preprocess_param = [
                "--source-lang",
                scorer1_src,
                "--target-lang",
                scorer1_tgt,
                "--trainpref",
                pre_gen + "/right_to_left_rescore_data",
                "--srcdict",
                args.score_dict_dir + "/dict." + scorer1_src + ".txt",
                "--tgtdict",
                args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
                "--destdir",
                right_to_left_preprocessed_dir,
            ]
            preprocess_parser = options.get_preprocessing_parser()
            input_args = preprocess_parser.parse_args(preprocess_param)
            preprocess.main(input_args)

    return gen_output
|
| 388 |
+
|
| 389 |
+
|
| 390 |
+
def cli_main():
    """Command-line entry point: parse reranking options, then generate and
    preprocess the n-best list."""
    reranking_parser = rerank_options.get_reranking_parser()
    parsed_args = options.parse_args_and_arch(reranking_parser)
    gen_and_reprocess_nbest(parsed_args)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    cli_main()
|
fairseq-0.10.2/examples/noisychannel/rerank_options.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from fairseq import options
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_reranking_parser(default_task="translation"):
    """Return an argument parser carrying generation plus reranking options."""
    rerank_parser = options.get_parser("Generation and reranking", default_task)
    add_reranking_args(rerank_parser)
    return rerank_parser
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def get_tuning_parser(default_task="translation"):
    """Return an argument parser carrying reranking plus hyperparameter-tuning
    options."""
    tuning_parser = options.get_parser("Reranking tuning", default_task)
    add_reranking_args(tuning_parser)
    add_tuning_args(tuning_parser)
    return tuning_parser
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def add_reranking_args(parser):
    """Add noisy-channel reranking options to *parser* and return the group.

    Covers the rescoring models, language-model options, prefix truncation,
    BPE handling, and sharding controls used by the rerank_* scripts.
    """
    group = parser.add_argument_group("Reranking")
    # fmt: off
    group.add_argument('--score-model1', '-s1', type=str, metavar='FILE', required=True,
                       help='path to first model or ensemble of models for rescoring')
    group.add_argument('--score-model2', '-s2', type=str, metavar='FILE', required=False,
                       help='path to second model or ensemble of models for rescoring')
    group.add_argument('--num-rescore', '-n', type=int, metavar='N', default=10,
                       help='the number of candidate hypothesis to rescore')
    group.add_argument('-bz', '--batch-size', type=int, metavar='N', default=128,
                       help='batch size for generating the nbest list')
    group.add_argument('--gen-subset', default='test', metavar='SET', choices=['test', 'train', 'valid'],
                       help='data subset to generate (train, valid, test)')
    group.add_argument('--gen-model', default=None, metavar='FILE',
                       help='the model to generate translations')
    group.add_argument('-b1', '--backwards1', action='store_true',
                       help='whether or not the first model group is backwards')
    group.add_argument('-b2', '--backwards2', action='store_true',
                       help='whether or not the second model group is backwards')
    group.add_argument('-a', '--weight1', default=1, nargs='+', type=float,
                       help='the weight(s) of the first model')
    group.add_argument('-b', '--weight2', default=1, nargs='+', type=float,
                       help='the weight(s) of the second model, or the gen model if using nbest from interactive.py')
    group.add_argument('-c', '--weight3', default=1, nargs='+', type=float,
                       help='the weight(s) of the third model')

    # lm arguments
    group.add_argument('-lm', '--language-model', default=None, metavar='FILE',
                       help='language model for target language to rescore translations')
    group.add_argument('--lm-dict', default=None, metavar='FILE',
                       help='the dict of the language model for the target language')
    group.add_argument('--lm-name', default=None,
                       help='the name of the language model for the target language')
    group.add_argument('--lm-bpe-code', default=None, metavar='FILE',
                       help='the bpe code for the language model for the target language')
    group.add_argument('--data-dir-name', default=None,
                       help='name of data directory')
    group.add_argument('--lenpen', default=1, nargs='+', type=float,
                       help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
    group.add_argument('--score-dict-dir', default=None,
                       help='the directory with dictionaries for the scoring models')
    group.add_argument('--right-to-left1', action='store_true',
                       help='whether the first model group is a right to left model')
    group.add_argument('--right-to-left2', action='store_true',
                       help='whether the second model group is a right to left model')
    group.add_argument('--remove-bpe', '--post-process', default='@@ ',
                       help='the bpe symbol, used for the bitext and LM')
    group.add_argument('--prefix-len', default=None, type=int,
                       help='the length of the target prefix to use in rescoring (in terms of words wo bpe)')
    group.add_argument('--sampling', action='store_true',
                       help='use sampling instead of beam search for generating n best list')
    group.add_argument('--diff-bpe', action='store_true',
                       help='bpe for rescoring and nbest list not the same')
    group.add_argument('--rescore-bpe-code', default=None,
                       help='bpe code for rescoring models')
    group.add_argument('--nbest-list', default=None,
                       help='use predefined nbest list in interactive.py format')
    group.add_argument('--write-hypos', default=None,
                       help='filename prefix to write hypos to')
    group.add_argument('--ref-translation', default=None,
                       help='reference translation to use with nbest list from interactive.py')
    group.add_argument('--backwards-score-dict-dir', default=None,
                       help='the directory with dictionaries for the backwards model,'
                       'if None then it is assumed the fw and backwards models share dictionaries')

    # extra scaling args
    group.add_argument('--gen-model-name', default=None,
                       help='the name of the models that generated the nbest list')
    group.add_argument('--model1-name', default=None,
                       help='the name of the set for model1 group ')
    group.add_argument('--model2-name', default=None,
                       help='the name of the set for model2 group')
    group.add_argument('--shard-id', default=0, type=int,
                       help='the id of the shard to generate')
    group.add_argument('--num-shards', default=1, type=int,
                       help='the number of shards to generate across')
    group.add_argument('--all-shards', action='store_true',
                       help='use all shards')
    group.add_argument('--target-prefix-frac', default=None, type=float,
                       help='the fraction of the target prefix to use in rescoring (in terms of words wo bpe)')
    group.add_argument('--source-prefix-frac', default=None, type=float,
                       help='the fraction of the source prefix to use in rescoring (in terms of words wo bpe)')
    group.add_argument('--normalize', action='store_true',
                       help='whether to normalize by src and target len')
    # fmt: on
    return group
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def add_tuning_args(parser):
    """Add random-search tuning options (bounds, tuned parameters, trial count)
    to *parser* and return the created argument group."""
    tuning_group = parser.add_argument_group("Tuning")
    # fmt: off
    tuning_group.add_argument('--lower-bound', default=[-0.7], nargs='+', type=float,
                              help='lower bound of search space')
    tuning_group.add_argument('--upper-bound', default=[3], nargs='+', type=float,
                              help='upper bound of search space')
    tuning_group.add_argument('--tune-param', default=['lenpen'], nargs='+',
                              choices=['lenpen', 'weight1', 'weight2', 'weight3'],
                              help='the parameter(s) to tune')
    tuning_group.add_argument('--tune-subset', default='valid',
                              choices=['valid', 'test', 'train'],
                              help='the subset to tune on ')
    tuning_group.add_argument('--num-trials', default=1000, type=int,
                              help='number of trials to do for random search')
    tuning_group.add_argument('--share-weights', action='store_true',
                              help='share weight2 and weight 3')
    # fmt: on
    return tuning_group
|
fairseq-0.10.2/examples/noisychannel/rerank_score_bw.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from contextlib import redirect_stdout
|
| 8 |
+
|
| 9 |
+
from fairseq import options
|
| 10 |
+
from fairseq_cli import generate
|
| 11 |
+
|
| 12 |
+
from . import rerank_options, rerank_utils
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def score_bw(args):
    """Score the preprocessed n-best list with the rescoring model(s).

    Runs ``fairseq_cli.generate`` in ``--score-reference`` mode over the
    binarized n-best data produced by ``gen_and_reprocess_nbest`` and writes
    each model's scores to its rescore file (STEP 4 of the pipeline).
    Scoring is skipped for a model whose score file already exists or whose
    scores can be taken from the generation output.
    """
    # A backwards model scores p(S|T): swap source and target for it.
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    if args.score_model2 is not None:
        if args.backwards2:
            scorer2_src = args.target_lang
            scorer2_tgt = args.source_lang
        else:
            scorer2_src = args.source_lang
            scorer2_tgt = args.target_lang

    # If a scorer is the generation model itself (and no source prefix
    # truncation is in play), its scores come from the generation output.
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )

    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    # Choose which binarized dataset matches model 1's orientation.
    if args.right_to_left1:
        rerank_data1 = right_to_left_preprocessed_dir
    elif args.backwards1:
        rerank_data1 = backwards_preprocessed_dir
    else:
        rerank_data1 = left_to_right_preprocessed_dir

    # Shared CLI flags: score the reference side of the binarized n-best data
    # (which was written as the "train" split by preprocess).
    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")

        model_param1 = [
            "--path",
            args.score_model1,
            "--source-lang",
            scorer1_src,
            "--target-lang",
            scorer1_tgt,
        ]
        gen_model1_param = [rerank_data1] + gen_param + model_param1

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)

        # generate.main prints scores to stdout; capture into the score file.
        with open(score1_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    if (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print("STEP 4: score the translations for model 2")

        # Choose which binarized dataset matches model 2's orientation.
        if args.right_to_left2:
            rerank_data2 = right_to_left_preprocessed_dir
        elif args.backwards2:
            rerank_data2 = backwards_preprocessed_dir
        else:
            rerank_data2 = left_to_right_preprocessed_dir

        model_param2 = [
            "--path",
            args.score_model2,
            "--source-lang",
            scorer2_src,
            "--target-lang",
            scorer2_tgt,
        ]
        gen_model2_param = [rerank_data2] + gen_param + model_param2

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)

        with open(score2_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def cli_main():
    """Command-line entry point: parse reranking options and run the
    backward/forward scoring step."""
    arg_parser = rerank_options.get_reranking_parser()
    parsed = options.parse_args_and_arch(arg_parser)
    score_bw(parsed)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
# Allow running this module directly as a script.
if __name__ == "__main__":
    cli_main()
|
fairseq-0.10.2/examples/noisychannel/rerank_tune.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
from fairseq import options
|
| 11 |
+
|
| 12 |
+
from . import rerank, rerank_options
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def random_search(args):
    """Tune reranking hyperparameters by random search.

    Samples ``args.num_trials`` values of the parameters named in
    ``args.tune_param`` uniformly within [lower_bound, upper_bound], holds the
    remaining tuneable parameters fixed at their given values, reranks on the
    tuning subset to find the best setting, then reruns reranking on "valid"
    and on ``args.gen_subset`` with the best hyperparameters.
    """
    param_values = []
    tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"]
    initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
    # Normalize every initial value to a list so all params are handled alike.
    for i, elem in enumerate(initial_params):
        if type(elem) is not list:
            initial_params[i] = [elem]
        else:
            initial_params[i] = elem

    # Remove the tuned parameters from the "fixed" lists; what remains in
    # tuneable_parameters/initial_params are the parameters held constant.
    tune_parameters = args.tune_param.copy()
    for i in range(len(args.tune_param)):
        assert args.upper_bound[i] >= args.lower_bound[i]
        index = tuneable_parameters.index(args.tune_param[i])
        del tuneable_parameters[index]
        del initial_params[index]

    # Column order: tuned parameters first, then the fixed ones.
    tune_parameters += tuneable_parameters
    param_values += initial_params
    random.seed(args.seed)

    # One row per trial: random draws for tuned params ...
    random_params = np.array(
        [
            [
                random.uniform(args.lower_bound[i], args.upper_bound[i])
                for i in range(len(args.tune_param))
            ]
            for k in range(args.num_trials)
        ]
    )
    # ... and the (repeated) fixed values for the rest.
    set_params = np.array(
        [
            [initial_params[i][0] for i in range(len(tuneable_parameters))]
            for k in range(args.num_trials)
        ]
    )
    random_params = np.concatenate((random_params, set_params), 1)

    rerank_args = vars(args).copy()
    if args.nbest_list:
        rerank_args["gen_subset"] = "test"
    else:
        rerank_args["gen_subset"] = args.tune_subset

    # Hand every trial's value for each parameter to rerank as a list.
    for k in range(len(tune_parameters)):
        rerank_args[tune_parameters[k]] = list(random_params[:, k])

    if args.share_weights:
        # Tie weight3 to the sampled weight2 column.
        k = tune_parameters.index("weight2")
        rerank_args["weight3"] = list(random_params[:, k])

    rerank_args = argparse.Namespace(**rerank_args)
    best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(
        rerank_args
    )
    rerank_args = vars(args).copy()
    rerank_args["lenpen"] = [best_lenpen]
    rerank_args["weight1"] = [best_weight1]
    rerank_args["weight2"] = [best_weight2]
    rerank_args["weight3"] = [best_weight3]

    # write the hypothesis from the valid set from the best trial

    if args.gen_subset != "valid":
        rerank_args["gen_subset"] = "valid"
    rerank_args = argparse.Namespace(**rerank_args)
    rerank.rerank(rerank_args)

    # test with the best hyperparameters on gen subset
    rerank_args = vars(args).copy()
    rerank_args["gen_subset"] = args.gen_subset
    rerank_args["lenpen"] = [best_lenpen]
    rerank_args["weight1"] = [best_weight1]
    rerank_args["weight2"] = [best_weight2]
    rerank_args["weight3"] = [best_weight3]
    rerank_args = argparse.Namespace(**rerank_args)
    rerank.rerank(rerank_args)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def cli_main():
    """Command-line entry point: parse tuning arguments and run the random search."""
    tuning_parser = rerank_options.get_tuning_parser()
    parsed_args = options.parse_args_and_arch(tuning_parser)
    random_search(parsed_args)


if __name__ == "__main__":
    cli_main()
|
fairseq-0.10.2/examples/noisychannel/rerank_utils.py
ADDED
|
@@ -0,0 +1,850 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
import re
|
| 9 |
+
import subprocess
|
| 10 |
+
from contextlib import redirect_stdout
|
| 11 |
+
|
| 12 |
+
from fairseq import options
|
| 13 |
+
from fairseq_cli import eval_lm, preprocess
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def reprocess(fle):
    """Parse the output of fairseq's generate.py into per-sentence dicts.

    Returns (source_dict, hypothesis_dict, score_dict, target_dict,
    pos_score_dict), each keyed by the integer sentence id. Hypotheses,
    scores and positional scores are stored as lists because one source
    may have several hypotheses.
    """
    with open(fle, "r") as f:
        txt = f.read()

    # line prefix, e.g. "H-12" followed by whitespace: type letter + id
    prefix_pat = re.compile(r"[STHP][-]\d+\s*")
    # overall hypothesis score (a float, possibly -inf)
    score_pat = re.compile(r"(\s*[-]?\d+[.]?\d+\s*)|(\s*(-inf)\s*)")

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    for line in txt.split("\n"):
        line += "\n"
        prefix = re.search(prefix_pat, line)
        if prefix is None:
            continue
        assert len(prefix.group()) > 2, "prefix id not found"
        _, j = prefix.span()
        id_num = int(prefix.group()[2:])
        line_type = prefix.group()[0]
        rest = line[j:]

        if line_type == "H":
            hypo = re.search(score_pat, rest)
            assert (
                hypo is not None
            ), "regular expression failed to find the hypothesis scoring"
            _, i = hypo.span()
            # text after the score is the hypothesis itself
            hypothesis_dict.setdefault(id_num, []).append(rest[i:])
            score_dict.setdefault(id_num, []).append(float(hypo.group()))
        elif line_type == "S":
            source_dict[id_num] = rest
        elif line_type == "T":
            target_dict[id_num] = rest
        elif line_type == "P":
            # per-token (positional) scores
            pos_score_dict.setdefault(id_num, []).append(
                [float(x) for x in rest.split()]
            )

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def reprocess_nbest(fle):
    """Parse the output of fairseq's interactive.py into per-sentence dicts.

    Same return structure as reprocess(). interactive.py output has no
    reference translations, so target entries are the placeholder "filler".
    """
    with open(fle, "r") as f:
        txt = f.read()

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    score_pat = re.compile(r"[-]?\d+[.]?\d+")
    sent_id = -1  # bumped each time an "O" (original/source) line is seen

    for raw in txt.split("\n"):
        raw += "\n"
        tag = raw[0]

        if tag == "H":
            match = re.search(score_pat, raw)
            _, start_index = match.span()
            value = float(match.group())
            text = raw[start_index:].strip("\t")
            if sent_id in score_dict:
                score_dict[sent_id].append(value)
                hypothesis_dict[sent_id].append(text)
            else:
                score_dict[sent_id] = [value]
                hypothesis_dict[sent_id] = [text]
        elif tag == "O":
            sent_id += 1
            source_dict[sent_id] = raw[2:]
            # we don't have the targets for interactive.py
            target_dict[sent_id] = "filler"
        elif tag == "P":
            token_scores = [float(tok) for tok in raw.split()[1:]]
            if sent_id in pos_score_dict:
                pos_score_dict[sent_id].append(token_scores)
            else:
                pos_score_dict[sent_id] = [token_scores]

    assert source_dict.keys() == hypothesis_dict.keys()
    assert source_dict.keys() == pos_score_dict.keys()
    assert source_dict.keys() == score_dict.keys()

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def write_reprocessed(
    sources,
    hypos,
    targets,
    source_outfile,
    hypo_outfile,
    target_outfile,
    right_to_left=False,
    prefix_len=None,
    bpe_symbol=None,
    target_prefix_frac=None,
    source_prefix_frac=None,
):

    """writes nbest hypothesis for rescoring

    Writes one line per (source, hypothesis) pair to the three output
    files; the source/target lines are duplicated once per hypothesis so
    all three files stay line-aligned. At most one prefix-truncation mode
    (prefix_len, target_prefix_frac, source_prefix_frac) may be active.
    """
    # The three truncation modes are mutually exclusive.
    assert not (
        prefix_len is not None and target_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"
    assert not (
        prefix_len is not None and source_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"
    assert not (
        target_prefix_frac is not None and source_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"

    with open(source_outfile, "w") as source_file, open(
        hypo_outfile, "w"
    ) as hypo_file, open(target_outfile, "w") as target_file:

        assert len(sources) == len(hypos), "sources and hypos list length mismatch"
        if right_to_left:
            # token-reversed output; prefix truncation is not supported here
            for i in range(len(sources)):
                for j in range(len(hypos[i])):
                    if prefix_len is None:
                        hypo_file.write(make_right_to_left(hypos[i][j]) + "\n")
                    else:
                        raise NotImplementedError()
                    source_file.write(make_right_to_left(sources[i]) + "\n")
                    target_file.write(make_right_to_left(targets[i]) + "\n")
        else:
            # iterate ids in sorted order so all output files line up
            for i in sorted(sources.keys()):
                for j in range(len(hypos[i])):
                    if prefix_len is not None:
                        # truncate the hypothesis to a fixed word-count prefix
                        shortened = (
                            get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)
                            + "\n"
                        )
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif target_prefix_frac is not None:
                        # truncate the hypothesis to a fraction of its words
                        num_words, shortened, num_bpe_tokens = calc_length_from_frac(
                            hypos[i][j], target_prefix_frac, bpe_symbol
                        )
                        shortened += "\n"
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif source_prefix_frac is not None:
                        # truncate the source instead of the hypothesis
                        num_words, shortened, num_bpe_tokensn = calc_length_from_frac(
                            sources[i], source_prefix_frac, bpe_symbol
                        )
                        shortened += "\n"
                        hypo_file.write(hypos[i][j])
                        source_file.write(shortened)
                        target_file.write(targets[i])
                    else:
                        # no truncation: write everything through unchanged
                        hypo_file.write(hypos[i][j])
                        source_file.write(sources[i])
                        target_file.write(targets[i])
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
    """Compute a word-level prefix covering prefix_frac of the sentence.

    Returns (num_words, prefix_with_bpe, num_bpe_tokens): the number of
    whole words kept (ceil of the fraction), the corresponding prefix in
    BPE form, and how many BPE tokens that prefix spans.
    """
    word_sentence = remove_bpe(bpe_sentence, bpe_symbol)
    total_words = len(word_sentence.split())

    num_words = math.ceil(total_words * prefix_frac)
    prefix = get_prefix_no_bpe(bpe_sentence, bpe_symbol, num_words)
    return num_words, prefix, len(prefix.split())
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def get_prefix(sentence, prefix_len):
    """Return the first prefix_len words of a plain (non-BPE) sentence.

    When the sentence has prefix_len words or fewer, it is returned whole
    (minus surrounding newlines), preserving its original spacing.
    """
    stripped = sentence.strip("\n")
    words = stripped.split()
    if prefix_len >= len(words):
        return stripped
    return " ".join(words[:prefix_len])
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def get_prefix_no_bpe(sentence, bpe_symbol, prefix_len):
    """Prefix of prefix_len words, BPE-aware when bpe_symbol is given."""
    if bpe_symbol is None:
        return get_prefix(sentence, prefix_len)
    return " ".join(get_prefix_from_len(sentence.split(), bpe_symbol, prefix_len))
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
    """Prefix of a BPE token list covering prefix_len words (not BPE tokens).

    Every token that still contains the BPE continuation marker belongs to a
    word extending past the cut, so the prefix is grown recursively by that
    count until no marker remains in the newly added tail.
    """
    marker = bpe_symbol.strip(" ")
    head = sentence[:prefix_len]
    continuations = sum(marker in tok for tok in head)
    if continuations == 0:
        return head
    return head + get_prefix_from_len(sentence[prefix_len:], bpe_symbol, continuations)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
    """Return how many BPE tokens cover a prefix of prefix_len words."""
    prefix = get_prefix_no_bpe(sentence, bpe_symbol, prefix_len)
    # sanity check: the de-BPE'd prefix must not exceed the word budget
    assert len(remove_bpe(prefix, bpe_symbol).split()) <= prefix_len
    return len(prefix.split(" "))
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def make_right_to_left(line):
    """Return the line with its whitespace-separated tokens reversed."""
    return " ".join(reversed(line.split()))
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def remove_bpe(line, bpe_symbol):
    """Undo BPE segmentation on one line; the result ends with a newline."""
    flattened = line.replace("\n", "")
    # trailing space lets a sentence-final bpe_symbol (e.g. "@@ ") match too
    merged = (flattened + " ").replace(bpe_symbol, "").rstrip()
    return merged + "\n"
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def remove_bpe_dict(pred_dict, bpe_symbol):
    """Apply remove_bpe to every value of pred_dict.

    Values may be single strings or lists of strings (e.g. several
    hypotheses per sentence id); the structure is preserved in the
    returned new dict, and pred_dict itself is not modified.
    """
    new_dict = {}
    for key, value in pred_dict.items():
        # isinstance instead of `type(x) == list` (idiomatic, handles subclasses)
        if isinstance(value, list):
            new_dict[key] = [remove_bpe(elem, bpe_symbol) for elem in value]
        else:
            new_dict[key] = remove_bpe(value, bpe_symbol)
    return new_dict
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def parse_bleu_scoring(line):
    """Extract the BLEU4 value from a scoring line such as "... BLEU4 = 27.40, ..."."""
    match = re.search(r"(BLEU4 = )\d+[.]\d+", line)
    assert match is not None, line
    # drop the leading "BLEU4 = " (8 characters) to isolate the number
    return float(match.group()[8:])
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def get_full_from_prefix(hypo_prefix, hypos):
    """Return the first complete hypothesis in hypos starting with hypo_prefix."""
    needle = hypo_prefix.strip("\n")
    for candidate in hypos:
        if candidate.startswith(needle):
            return candidate
    # no match found
    raise Exception()
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
def get_score(
    a,
    b,
    c,
    target_len,
    bitext_score1,
    bitext_score2=None,
    lm_score=None,
    lenpen=None,
    src_len=None,
    tgt_len=None,
    bitext1_backwards=False,
    bitext2_backwards=False,
    normalize=False,
):
    """Combine two bitext-model scores and an LM score into one rerank score.

    a, b, c weight bitext_score1, bitext_score2 and lm_score respectively.
    With normalize=True each score is divided by its sentence length
    (src_len for backwards models, tgt_len otherwise; src_len for the LM).
    A missing second bitext score contributes nothing. When lenpen is set,
    the total is divided by target_len ** lenpen.
    """
    # pick the normalizer matching each model's scoring direction
    bitext1_norm = src_len if bitext1_backwards else tgt_len

    if bitext_score2 is None:
        # second model absent: contribute zero, avoid dividing by None
        bitext_score2 = 0
        bitext2_norm = 1
    else:
        bitext2_norm = src_len if bitext2_backwards else tgt_len

    if normalize:
        score = (
            a * bitext_score1 / bitext1_norm
            + b * bitext_score2 / bitext2_norm
            + c * lm_score / src_len
        )
    else:
        score = a * bitext_score1 + b * bitext_score2 + c * lm_score

    if lenpen is not None:
        score /= target_len ** float(lenpen)

    return score
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class BitextOutput(object):
    # Wraps one rescoring-model output file. After __init__ the rescore_*
    # dicts hold de-BPE'd text, raw (length-penalty-free) scores and
    # per-token scores, keyed by sentence id, plus source/target lengths.
    def __init__(
        self,
        output_file,
        backwards,
        right_to_left,
        bpe_symbol,
        prefix_len=None,
        target_prefix_frac=None,
        source_prefix_frac=None,
    ):
        """process output from rescoring"""
        source, hypo, score, target, pos_score = reprocess(output_file)
        # for a backwards model the "hypothesis" side is the source, so the
        # applicable prefix fraction comes from the source side
        if backwards:
            self.hypo_fracs = source_prefix_frac
        else:
            self.hypo_fracs = target_prefix_frac

        # remove length penalty so we can use raw scores
        score, num_bpe_tokens = get_score_from_pos(
            pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards
        )
        source_lengths = {}
        target_lengths = {}

        assert hypo.keys() == source.keys(), "key mismatch"
        if backwards:
            # swap roles: the backwards model scored source given hypothesis
            tmp = hypo
            hypo = source
            source = tmp
        for i in source:
            # since we are reranking, there should only be one hypo per source sentence
            if backwards:
                len_src = len(source[i][0].split())
                # record length without <eos>
                if len_src == num_bpe_tokens[i][0] - 1:
                    source_lengths[i] = num_bpe_tokens[i][0] - 1
                else:
                    source_lengths[i] = num_bpe_tokens[i][0]

                target_lengths[i] = len(hypo[i].split())

                source[i] = remove_bpe(source[i][0], bpe_symbol)
                target[i] = remove_bpe(target[i], bpe_symbol)
                hypo[i] = remove_bpe(hypo[i], bpe_symbol)

                score[i] = float(score[i][0])
                pos_score[i] = pos_score[i][0]

            else:
                len_tgt = len(hypo[i][0].split())
                # record length without <eos>
                if len_tgt == num_bpe_tokens[i][0] - 1:
                    target_lengths[i] = num_bpe_tokens[i][0] - 1
                else:
                    target_lengths[i] = num_bpe_tokens[i][0]

                source_lengths[i] = len(source[i].split())

                if right_to_left:
                    # un-reverse tokens before stripping BPE
                    source[i] = remove_bpe(make_right_to_left(source[i]), bpe_symbol)
                    target[i] = remove_bpe(make_right_to_left(target[i]), bpe_symbol)
                    hypo[i] = remove_bpe(make_right_to_left(hypo[i][0]), bpe_symbol)
                    score[i] = float(score[i][0])
                    pos_score[i] = pos_score[i][0]
                else:
                    assert (
                        len(hypo[i]) == 1
                    ), "expected only one hypothesis per source sentence"
                    source[i] = remove_bpe(source[i], bpe_symbol)
                    target[i] = remove_bpe(target[i], bpe_symbol)
                    hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
                    score[i] = float(score[i][0])
                    pos_score[i] = pos_score[i][0]

        self.rescore_source = source
        self.rescore_hypo = hypo
        self.rescore_score = score
        self.rescore_target = target
        self.rescore_pos_score = pos_score
        self.backwards = backwards
        self.right_to_left = right_to_left
        self.target_lengths = target_lengths
        self.source_lengths = source_lengths
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
class BitextOutputFromGen(object):
    # Wraps the original generate.py / interactive.py n-best output and
    # re-indexes it so every (sentence, hypothesis) pair gets its own flat
    # index, matching the per-line indexing of the rescoring models.
    def __init__(
        self,
        predictions_bpe_file,
        bpe_symbol=None,
        nbest=False,
        prefix_len=None,
        target_prefix_frac=None,
    ):
        # nbest=True means interactive.py output; otherwise generate.py output
        if nbest:
            (
                pred_source,
                pred_hypo,
                pred_score,
                pred_target,
                pred_pos_score,
            ) = reprocess_nbest(predictions_bpe_file)
        else:
            pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(
                predictions_bpe_file
            )

        assert len(pred_source) == len(pred_hypo)
        assert len(pred_source) == len(pred_score)
        assert len(pred_source) == len(pred_target)
        assert len(pred_source) == len(pred_pos_score)

        # remove length penalty so we can use raw scores
        pred_score, num_bpe_tokens = get_score_from_pos(
            pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False
        )

        self.source = pred_source
        self.target = pred_target
        self.score = pred_score
        self.pos_score = pred_pos_score
        self.hypo = pred_hypo
        self.target_lengths = {}
        self.source_lengths = {}

        self.no_bpe_source = remove_bpe_dict(pred_source.copy(), bpe_symbol)
        self.no_bpe_hypo = remove_bpe_dict(pred_hypo.copy(), bpe_symbol)
        self.no_bpe_target = remove_bpe_dict(pred_target.copy(), bpe_symbol)

        # indexes to match those from the rescoring models
        self.rescore_source = {}
        self.rescore_target = {}
        self.rescore_pos_score = {}
        self.rescore_hypo = {}
        self.rescore_score = {}
        self.num_hypos = {}
        self.backwards = False
        self.right_to_left = False

        index = 0

        # flatten (sentence id, hypothesis rank) into a single running index
        for i in sorted(pred_source.keys()):
            for j in range(len(pred_hypo[i])):

                self.target_lengths[index] = len(self.hypo[i][j].split())
                self.source_lengths[index] = len(self.source[i].split())

                self.rescore_source[index] = self.no_bpe_source[i]
                self.rescore_target[index] = self.no_bpe_target[i]
                self.rescore_hypo[index] = self.no_bpe_hypo[i][j]
                self.rescore_score[index] = float(pred_score[i][j])
                self.rescore_pos_score[index] = pred_pos_score[i][j]
                self.num_hypos[index] = len(pred_hypo[i])
                index += 1
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
def get_score_from_pos(
    pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards
):
    """Recompute sentence scores as raw sums of per-token scores.

    Summing positional scores removes any length penalty baked into the
    reported sentence score. When a prefix is requested (fixed prefix_len,
    forward direction only, or a fraction hypo_frac) only the scores of the
    covered BPE tokens are summed. Returns (score_dict,
    num_bpe_tokens_dict), both keyed like pos_score_dict with one entry per
    hypothesis.
    """
    score_dict = {}
    num_bpe_tokens_dict = {}
    # the two prefix modes are mutually exclusive
    assert prefix_len is None or hypo_frac is None
    for key in pos_score_dict:
        score_dict[key] = []
        num_bpe_tokens_dict[key] = []
        for i in range(len(pos_score_dict[key])):
            if prefix_len is not None and not backwards:
                # fixed word-count prefix: sum scores of its BPE tokens
                num_bpe_tokens = get_num_bpe_tokens_from_len(
                    hypo_dict[key][i], bpe_symbol, prefix_len
                )
                score_dict[key].append(sum(pos_score_dict[key][i][:num_bpe_tokens]))
                num_bpe_tokens_dict[key].append(num_bpe_tokens)
            elif hypo_frac is not None:
                # fractional prefix of the hypothesis
                num_words, shortened, hypo_prefix_len = calc_length_from_frac(
                    hypo_dict[key][i], hypo_frac, bpe_symbol
                )
                score_dict[key].append(sum(pos_score_dict[key][i][:hypo_prefix_len]))
                num_bpe_tokens_dict[key].append(hypo_prefix_len)
            else:
                # no prefix: sum everything
                score_dict[key].append(sum(pos_score_dict[key][i]))
                num_bpe_tokens_dict[key].append(len(pos_score_dict[key][i]))
    return score_dict, num_bpe_tokens_dict
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
class LMOutput(object):
    # Thin container around parse_lm(): holds per-sentence LM scores,
    # per-token scores, de-BPE'd sentences and BPE token counts.
    def __init__(
        self,
        lm_score_file,
        lm_dict=None,
        prefix_len=None,
        bpe_symbol=None,
        target_prefix_frac=None,
    ):
        (
            lm_sentences,
            lm_sen_scores,
            lm_sen_pos_scores,
            lm_no_bpe_sentences,
            lm_bpe_tokens,
        ) = parse_lm(
            lm_score_file,
            prefix_len=prefix_len,
            bpe_symbol=bpe_symbol,
            target_prefix_frac=target_prefix_frac,
        )

        self.sentences = lm_sentences
        self.score = lm_sen_scores
        self.pos_score = lm_sen_pos_scores
        self.lm_dict = lm_dict  # stored as given; not used within this class
        self.no_bpe_sentences = lm_no_bpe_sentences
        self.bpe_tokens = lm_bpe_tokens
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
    """parse output of eval_lm

    Returns (sentences, sen_scores, sen_pos_scores, no_bpe_sentences,
    num_bpe_tokens_dict), all keyed by the integer sentence id found at the
    start of each scored line.
    """
    with open(input_file, "r") as f:
        text = f.readlines()
        # drop the eval_lm header (assumed to be the first 7 lines) and the
        # trailing summary lines -- TODO confirm these counts stay fixed
        text = text[7:]
        cleaned_text = text[:-2]

    sentences = {}
    sen_scores = {}
    sen_pos_scores = {}
    no_bpe_sentences = {}
    num_bpe_tokens_dict = {}
    for _i, line in enumerate(cleaned_text):
        tokens = line.split()
        if tokens[0].isdigit():
            line_id = int(tokens[0])
            # tokens alternate word/score; scores are wrapped in one
            # character on each side (e.g. brackets), hence x[1:-1]
            scores = [float(x[1:-1]) for x in tokens[2::2]]
            sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n"
            if bpe_symbol is not None:
                # exclude <eos> symbol to match output from generate.py
                bpe_sen = " ".join(tokens[1::2][:-1]) + "\n"
                no_bpe_sen = remove_bpe(bpe_sen, bpe_symbol)
                no_bpe_sentences[line_id] = no_bpe_sen

            # NOTE(review): both prefix branches use bpe_sen, which is only
            # bound when bpe_symbol is not None -- prefixes require BPE here
            if prefix_len is not None:
                num_bpe_tokens = get_num_bpe_tokens_from_len(
                    bpe_sen, bpe_symbol, prefix_len
                )
                sen_scores[line_id] = sum(scores[:num_bpe_tokens])
                num_bpe_tokens_dict[line_id] = num_bpe_tokens
            elif target_prefix_frac is not None:
                num_words, shortened, target_prefix_len = calc_length_from_frac(
                    bpe_sen, target_prefix_frac, bpe_symbol
                )
                sen_scores[line_id] = sum(scores[:target_prefix_len])
                num_bpe_tokens_dict[line_id] = target_prefix_len
            else:
                sen_scores[line_id] = sum(scores)
                num_bpe_tokens_dict[line_id] = len(scores)

            sen_pos_scores[line_id] = scores

    return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
|
| 583 |
+
|
| 584 |
+
|
| 585 |
+
def get_directories(
    data_dir_name,
    num_rescore,
    gen_subset,
    fw_name,
    shard_id,
    num_shards,
    sampling=False,
    prefix_len=None,
    target_prefix_frac=None,
    source_prefix_frac=None,
):
    """Build the directory paths used to store data for one n-best list.

    Returns (pre_gen, left_to_right_preprocessed_dir,
    right_to_left_preprocessed_dir, backwards_preprocessed_dir,
    lm_preprocessed_dir). Directories are not created here; only the path
    strings are assembled.
    """
    nbest_file_id = (
        f"nbest_{num_rescore}_subset_{gen_subset}_fw_name_{fw_name}"
        f"_shard_{shard_id}_of_{num_shards}"
    )
    if sampling:
        nbest_file_id += "_sampling"

    # the directory containing all information for this nbest list
    pre_gen = (
        os.path.join(os.path.dirname(__file__))
        + "/rerank_data/"
        + data_dir_name
        + "/"
        + nbest_file_id
    )

    # preprocessed nbest list for left-to-right rescoring
    left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed"
    if source_prefix_frac is not None:
        left_to_right_preprocessed_dir += "/prefix_frac" + str(source_prefix_frac)

    # preprocessed nbest list for right-to-left rescoring
    right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed"

    # preprocessed nbest list for backwards rescoring
    backwards_preprocessed_dir = pre_gen + "/backwards"
    if target_prefix_frac is not None:
        backwards_preprocessed_dir += "/prefix_frac" + str(target_prefix_frac)
    elif prefix_len is not None:
        backwards_preprocessed_dir += "/prefix_" + str(prefix_len)

    # preprocessed nbest list for rescoring with P(T)
    lm_preprocessed_dir = pre_gen + "/lm_preprocessed"

    return (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    )
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def lm_scoring(
    preprocess_directory,
    bpe_status,
    gen_output,
    pre_gen,
    cur_lm_dict,
    cur_lm_name,
    cur_language_model,
    cur_lm_bpe_code,
    batch_size,
    lm_score_file,
    target_lang,
    source_lang,
    prefix_len=None,
):
    """Score an n-best list with a language model and write per-word probs.

    Depending on ``bpe_status`` the hypotheses are first re-encoded so their
    vocabulary matches the LM:

    - ``"no bpe"``: strip BPE from the n-best list and score the plain text.
    - ``"shared"``: the LM shares the translation model's BPE; score directly.
    - ``"different"``: strip the translation BPE, then re-apply the LM's own
      BPE code (via subword-nmt) before scoring.

    Args:
        preprocess_directory: destination dir for the binarized LM input.
        bpe_status: one of "no bpe", "shared", "different".
        gen_output: parsed generation output (provides ``no_bpe_*`` fields).
        pre_gen: directory holding the n-best list artifacts.
        cur_lm_dict: path to the LM dictionary (--srcdict).
        cur_lm_name: LM name; kept for interface compatibility (unused here).
        cur_language_model: path to the LM checkpoint.
        cur_lm_bpe_code: BPE codes for the LM (only used when "different").
        batch_size: eval-lm batch size.
        lm_score_file: file that captures eval_lm's stdout (the word probs).
        target_lang / source_lang: language suffixes for the data files.
        prefix_len: optional hypothesis prefix length; requires "different".

    Side effects: writes reprocessed data files under ``pre_gen``, binarized
    data under ``preprocess_directory``, and the scores to ``lm_score_file``.
    """
    if prefix_len is not None:
        assert (
            bpe_status == "different"
        ), "bpe status must be different to use prefix len"

    if bpe_status == "no bpe":
        # Run the LM on the output with BPE removed.
        write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            pre_gen + "/rescore_data_no_bpe.de",
            pre_gen + "/rescore_data_no_bpe.en",
            pre_gen + "/reference_file_no_bpe",
        )
        _preprocess_for_lm(
            pre_gen + "/rescore_data_no_bpe." + target_lang,
            cur_lm_dict,
            preprocess_directory,
        )
        _run_eval_lm(
            preprocess_directory,
            cur_language_model,
            batch_size,
            lm_score_file,
            max_tokens="1024",
        )

    elif bpe_status == "shared":
        # LM and translation model share BPE: score the n-best list as-is.
        _preprocess_for_lm(
            pre_gen + "/rescore_data." + target_lang,
            cur_lm_dict,
            preprocess_directory,
        )
        # NOTE: historically this branch ran eval_lm without --max-tokens;
        # preserved for backward compatibility.
        _run_eval_lm(
            preprocess_directory,
            cur_language_model,
            batch_size,
            lm_score_file,
            max_tokens=None,
        )

    elif bpe_status == "different":
        rescore_file = pre_gen + "/rescore_data_no_bpe."
        rescore_bpe = pre_gen + "/rescore_data_new_bpe."

        write_reprocessed(
            gen_output.no_bpe_source,
            gen_output.no_bpe_hypo,
            gen_output.no_bpe_target,
            rescore_file + source_lang,
            rescore_file + target_lang,
            pre_gen + "/reference_file_no_bpe",
            bpe_symbol=None,
        )

        # Apply the LM's BPE to the n-best list.
        bpe_src_param = [
            "-c",
            cur_lm_bpe_code,
            "--input",
            rescore_file + target_lang,
            "--output",
            rescore_bpe + target_lang,
        ]
        subprocess.call(
            [
                "python",
                os.path.join(
                    os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
                ),
            ]
            + bpe_src_param,
            shell=False,
        )
        # uncomment to use fastbpe instead of subword-nmt bpe
        # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
        # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)

        _preprocess_for_lm(
            rescore_bpe + target_lang,
            cur_lm_dict,
            preprocess_directory,
        )
        _run_eval_lm(
            preprocess_directory,
            cur_language_model,
            batch_size,
            lm_score_file,
            max_tokens="1024",
        )


def _preprocess_for_lm(trainpref, lm_dict, destdir):
    """Binarize ``trainpref`` with the LM dictionary into ``destdir``."""
    preprocess_lm_param = [
        "--only-source",
        "--trainpref",
        trainpref,
        "--srcdict",
        lm_dict,
        "--destdir",
        destdir,
    ]
    preprocess_parser = options.get_preprocessing_parser()
    input_args = preprocess_parser.parse_args(preprocess_lm_param)
    preprocess.main(input_args)


def _run_eval_lm(data_dir, language_model, batch_size, lm_score_file, max_tokens=None):
    """Run eval_lm on ``data_dir``, redirecting stdout to ``lm_score_file``.

    ``max_tokens`` is passed through as a string when not None (some call
    sites historically omitted it).
    """
    eval_lm_param = [
        data_dir,
        "--path",
        language_model,
        "--output-word-probs",
        "--batch-size",
        str(batch_size),
    ]
    if max_tokens is not None:
        eval_lm_param += ["--max-tokens", max_tokens]
    eval_lm_param += [
        "--sample-break-mode",
        "eos",
        "--gen-subset",
        "train",
    ]

    eval_lm_parser = options.get_eval_lm_parser()
    input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)

    with open(lm_score_file, "w") as f:
        with redirect_stdout(f):
            eval_lm.main(input_args)
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def rescore_file_name(
    nbest_dir,
    prefix_len,
    scorer_name,
    lm_file=False,
    target_prefix_frac=None,
    source_prefix_frac=None,
    backwards=None,
):
    """Return the path of the score file produced by one rescoring model.

    LM score files and translation-model score files use different naming
    schemes; prefix-length / prefix-fraction variants get a suffix appended
    so runs with different truncation settings do not collide.
    """
    if lm_file:
        base = f"{nbest_dir}/lm_score_translations_model_{scorer_name}.txt"
    else:
        base = f"{nbest_dir}/{scorer_name}_score_translations.txt"

    suffix = ""
    if backwards:
        # Backwards models may be scored on a target prefix.
        if prefix_len is not None:
            suffix = f"prefix_len{prefix_len}"
        elif target_prefix_frac is not None:
            suffix = f"target_prefix_frac{target_prefix_frac}"
    elif source_prefix_frac is not None:
        suffix = f"source_prefix_frac{source_prefix_frac}"

    return base + suffix
|
fairseq-0.10.2/examples/paraphraser/README.md
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Paraphrasing with round-trip translation and mixture of experts
|
| 2 |
+
|
| 3 |
+
Machine translation models can be used to paraphrase text by translating it to
|
| 4 |
+
an intermediate language and back (round-trip translation).
|
| 5 |
+
|
| 6 |
+
This example shows how to paraphrase text by first passing it to an
|
| 7 |
+
English-French translation model, followed by a French-English [mixture of
|
| 8 |
+
experts translation model](/examples/translation_moe).
|
| 9 |
+
|
| 10 |
+
##### 0. Setup
|
| 11 |
+
|
| 12 |
+
Clone fairseq from source and install necessary dependencies:
|
| 13 |
+
```bash
|
| 14 |
+
git clone https://github.com/pytorch/fairseq.git
|
| 15 |
+
cd fairseq
|
| 16 |
+
pip install --editable .
|
| 17 |
+
pip install sacremoses sentencepiece
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
##### 1. Download models
|
| 21 |
+
```bash
|
| 22 |
+
wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.en-fr.tar.gz
|
| 23 |
+
wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.fr-en.hMoEup.tar.gz
|
| 24 |
+
tar -xzvf paraphraser.en-fr.tar.gz
|
| 25 |
+
tar -xzvf paraphraser.fr-en.hMoEup.tar.gz
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
##### 2. Paraphrase
|
| 29 |
+
```bash
|
| 30 |
+
python examples/paraphraser/paraphrase.py \
|
| 31 |
+
--en2fr paraphraser.en-fr \
|
| 32 |
+
--fr2en paraphraser.fr-en.hMoEup
|
| 33 |
+
# Example input:
|
| 34 |
+
# The new date for the Games, postponed for a year in response to the coronavirus pandemic, gives athletes time to recalibrate their training schedules.
|
| 35 |
+
# Example outputs:
|
| 36 |
+
# Delayed one year in response to the coronavirus pandemic, the new date of the Games gives athletes time to rebalance their training schedule.
|
| 37 |
+
# The new date of the Games, which was rescheduled one year in response to the coronavirus (CV) pandemic, gives athletes time to rebalance their training schedule.
|
| 38 |
+
# The new date of the Games, postponed one year in response to the coronavirus pandemic, provides athletes with time to rebalance their training schedule.
|
| 39 |
+
# The Games' new date, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
|
| 40 |
+
# The new Games date, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
|
| 41 |
+
# The new date of the Games, which was postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
|
| 42 |
+
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
|
| 43 |
+
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to re-balance their training schedule.
|
| 44 |
+
# The new date of the Games, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their schedule of training.
|
| 45 |
+
# The new date of the Games, postponed one year in response to the pandemic of coronavirus, gives the athletes time to rebalance their training schedule.
|
| 46 |
+
```
|
fairseq-0.10.2/examples/paraphraser/paraphrase.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3 -u
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import fileinput
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from fairseq.models.transformer import TransformerModel
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
logging.getLogger().setLevel(logging.INFO)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def main():
    """Paraphrase input text by round-trip translation (en -> fr -> en).

    Loads an English->French transformer and a French->English
    mixture-of-experts transformer, then prints one paraphrase per expert
    for each non-empty input line (from files or stdin).
    """
    arg_parser = argparse.ArgumentParser(description="")
    arg_parser.add_argument("--en2fr", required=True, help="path to en2fr model")
    arg_parser.add_argument(
        "--fr2en", required=True, help="path to fr2en mixture of experts model"
    )
    arg_parser.add_argument(
        "--user-dir", help="path to fairseq examples/translation_moe/src directory"
    )
    arg_parser.add_argument(
        "--num-experts",
        type=int,
        default=10,
        help="(keep at 10 unless using a different model)",
    )
    arg_parser.add_argument(
        "files",
        nargs="*",
        default=["-"],
        help='input files to paraphrase; "-" for stdin',
    )
    opts = arg_parser.parse_args()

    # Default user_dir: <repo>/examples/translation_moe/src next to this script.
    if opts.user_dir is None:
        examples_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        opts.user_dir = os.path.join(examples_dir, "translation_moe", "src")
    if not os.path.exists(opts.user_dir):
        raise RuntimeError(
            "cannot find fairseq examples/translation_moe/src "
            "(tried looking here: {})".format(opts.user_dir)
        )
    logging.info("found user_dir:" + opts.user_dir)

    logging.info("loading en2fr model from:" + opts.en2fr)
    fwd_model = TransformerModel.from_pretrained(
        model_name_or_path=opts.en2fr,
        tokenizer="moses",
        bpe="sentencepiece",
    ).eval()

    logging.info("loading fr2en model from:" + opts.fr2en)
    bwd_model = TransformerModel.from_pretrained(
        model_name_or_path=opts.fr2en,
        tokenizer="moses",
        bpe="sentencepiece",
        user_dir=opts.user_dir,
        task="translation_moe",
    ).eval()

    def paraphrases_of(sentence):
        # Translate to French once, then back to English with every expert.
        intermediate = fwd_model.translate(sentence)
        outputs = []
        for expert_id in range(opts.num_experts):
            outputs.append(
                bwd_model.translate(
                    intermediate, inference_step_args={"expert": expert_id}
                )
            )
        return outputs

    logging.info("Type the input sentence and press return:")
    for raw_line in fileinput.input(opts.files):
        sentence = raw_line.strip()
        if not sentence:
            continue
        for candidate in paraphrases_of(sentence):
            print(candidate)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
if __name__ == "__main__":
|
| 85 |
+
main()
|
fairseq-0.10.2/examples/simultaneous_translation/README.md
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Simultaneous Machine Translation
|
| 2 |
+
|
| 3 |
+
This directory contains the code for the paper [Monotonic Multihead Attention](https://openreview.net/forum?id=Hyg96gBKPS)
|
| 4 |
+
|
| 5 |
+
## Prepare Data
|
| 6 |
+
|
| 7 |
+
[Please follow the instructions to download and preprocess the WMT'15 En-De dataset.](https://github.com/pytorch/fairseq/tree/simulastsharedtask/examples/translation#prepare-wmt14en2desh)
|
| 8 |
+
|
| 9 |
+
## Training
|
| 10 |
+
|
| 11 |
+
- MMA-IL
|
| 12 |
+
|
| 13 |
+
```shell
|
| 14 |
+
fairseq-train \
|
| 15 |
+
data-bin/wmt15_en_de_32k \
|
| 16 |
+
--simul-type infinite_lookback \
|
| 17 |
+
--user-dir $FAIRSEQ/example/simultaneous_translation \
|
| 18 |
+
--mass-preservation \
|
| 19 |
+
--criterion latency_augmented_label_smoothed_cross_entropy \
|
| 20 |
+
--latency-weight-avg 0.1 \
|
| 21 |
+
--max-update 50000 \
|
| 22 |
+
--arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
|
| 23 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' \
|
| 24 |
+
--lr-scheduler 'inverse_sqrt' \
|
| 25 |
+
--warmup-init-lr 1e-7 --warmup-updates 4000 \
|
| 26 |
+
--lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
|
| 27 |
+
--dropout 0.3 \
|
| 28 |
+
--label-smoothing 0.1\
|
| 29 |
+
--max-tokens 3584
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
- MMA-H
|
| 33 |
+
|
| 34 |
+
```shell
|
| 35 |
+
fairseq-train \
|
| 36 |
+
data-bin/wmt15_en_de_32k \
|
| 37 |
+
--simul-type hard_aligned \
|
| 38 |
+
--user-dir $FAIRSEQ/example/simultaneous_translation \
|
| 39 |
+
--mass-preservation \
|
| 40 |
+
--criterion latency_augmented_label_smoothed_cross_entropy \
|
| 41 |
+
--latency-weight-var 0.1 \
|
| 42 |
+
--max-update 50000 \
|
| 43 |
+
--arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
|
| 44 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' \
|
| 45 |
+
--lr-scheduler 'inverse_sqrt' \
|
| 46 |
+
--warmup-init-lr 1e-7 --warmup-updates 4000 \
|
| 47 |
+
--lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
|
| 48 |
+
--dropout 0.3 \
|
| 49 |
+
--label-smoothing 0.1\
|
| 50 |
+
--max-tokens 3584
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
- wait-k
|
| 54 |
+
|
| 55 |
+
```shell
|
| 56 |
+
fairseq-train \
|
| 57 |
+
data-bin/wmt15_en_de_32k \
|
| 58 |
+
--simul-type wait-k \
|
| 59 |
+
--waitk-lagging 3 \
|
| 60 |
+
--user-dir $FAIRSEQ/example/simultaneous_translation \
|
| 61 |
+
--mass-preservation \
|
| 62 |
+
--criterion latency_augmented_label_smoothed_cross_entropy \
|
| 63 |
+
--max-update 50000 \
|
| 64 |
+
--arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
|
| 65 |
+
--optimizer adam --adam-betas '(0.9, 0.98)' \
|
| 66 |
+
--lr-scheduler 'inverse_sqrt' \
|
| 67 |
+
--warmup-init-lr 1e-7 --warmup-updates 4000 \
|
| 68 |
+
--lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
|
| 69 |
+
--dropout 0.3 \
|
| 70 |
+
--label-smoothing 0.1\
|
| 71 |
+
--max-tokens 3584
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
## Evaluation
|
| 76 |
+
|
| 77 |
+
More details on evaluation can be found [here](https://github.com/pytorch/fairseq/blob/simulastsharedtask/examples/simultaneous_translation/docs/evaluation.md)
|
| 78 |
+
|
| 79 |
+
### Start the server
|
| 80 |
+
|
| 81 |
+
```shell
|
| 82 |
+
python ./eval/server.py \
|
| 83 |
+
--src-file $SRC_FILE \
|
| 84 |
+
--ref-file $TGT_FILE
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Run the client
|
| 88 |
+
|
| 89 |
+
```shell
|
| 90 |
+
python ./evaluate.py \
|
| 91 |
+
--data-bin data-bin/wmt15_en_de_32k \
|
| 92 |
+
--model-path ./checkpoints/checkpoint_best.pt
|
| 93 |
+
--scores --output $RESULT_DIR
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### Run evaluation locally without server
|
| 97 |
+
|
| 98 |
+
```shell
|
| 99 |
+
python ./eval/evaluate.py
|
| 100 |
+
--local \
|
| 101 |
+
--src-file $SRC_FILE \
|
| 102 |
+
--tgt-file $TGT_FILE \
|
| 103 |
+
--data-bin data-bin/wmt15_en_de_32k \
|
| 104 |
+
--model-path ./checkpoints/checkpoint_best.pt \
|
| 105 |
+
--scores --output $RESULT_DIR
|
| 106 |
+
```
|
fairseq-0.10.2/examples/simultaneous_translation/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from . import criterions, eval, models # noqa
|
fairseq-0.10.2/examples/simultaneous_translation/docs/baseline.md
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# **Baseline Simultaneous Translation**
|
| 2 |
+
---
|
| 3 |
+
|
| 4 |
+
This is a guide to training and evaluating a *wait-k* simultaneous LSTM model on the MuST-C English-German dataset.
|
| 5 |
+
|
| 6 |
+
[STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework](https://www.aclweb.org/anthology/P19-1289/)
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
## **Requirements**
|
| 10 |
+
Install fairseq (make sure to use the correct branch):
|
| 11 |
+
```
|
| 12 |
+
git clone --branch simulastsharedtask git@github.com:pytorch/fairseq.git
|
| 13 |
+
cd fairseq
|
| 14 |
+
pip install -e .
|
| 15 |
+
```
|
| 16 |
+
|
| 17 |
+
Assuming that fairseq is installed in a directory called `FAIRSEQ`.
|
| 18 |
+
|
| 19 |
+
Install SentencePiece. One easy way is to use anaconda:
|
| 20 |
+
|
| 21 |
+
```
|
| 22 |
+
conda install -c powerai sentencepiece
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
Download the MuST-C data for English-German available at https://ict.fbk.eu/must-c/.
|
| 26 |
+
We will assume that the data is downloaded in a directory called `DATA_ROOT`.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## **Text-to-text Model**
|
| 30 |
+
---
|
| 31 |
+
### Data Preparation
|
| 32 |
+
Train a SentencePiece model:
|
| 33 |
+
```shell
|
| 34 |
+
for lang in en de; do
|
| 35 |
+
python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \
|
| 36 |
+
--data-path $DATA_ROOT/data \
|
| 37 |
+
--vocab-size 10000 \
|
| 38 |
+
--max-frame 3000 \
|
| 39 |
+
--model-type unigram \
|
| 40 |
+
--lang $lang \
|
| 41 |
+
--out-path .
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
Process the data with the SentencePiece model:
|
| 45 |
+
```shell
|
| 46 |
+
proc_dir=proc
|
| 47 |
+
mkdir -p $proc_dir
|
| 48 |
+
for split in train dev tst-COMMON tst-HE; do
|
| 49 |
+
for lang in en de; do
|
| 50 |
+
spm_encode \
|
| 51 |
+
--model unigram-$lang-10000-3000/spm.model \
|
| 52 |
+
< $DATA_ROOT/data/$split/txt/$split.$lang \
|
| 53 |
+
> $proc_dir/$split.spm.$lang
|
| 54 |
+
done
|
| 55 |
+
done
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Binarize the data:
|
| 59 |
+
|
| 60 |
+
```shell
|
| 61 |
+
proc_dir=proc
|
| 62 |
+
fairseq-preprocess \
|
| 63 |
+
--source-lang en --target-lang de \
|
| 64 |
+
--trainpref $proc_dir/train.spm \
|
| 65 |
+
--validpref $proc_dir/dev.spm \
|
| 66 |
+
--testpref $proc_dir/tst-COMMON.spm \
|
| 67 |
+
--thresholdtgt 0 \
|
| 68 |
+
--thresholdsrc 0 \
|
| 69 |
+
--workers 20 \
|
| 70 |
+
--destdir ./data-bin/mustc_en_de \
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### Training
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
```shell
|
| 77 |
+
mkdir -p checkpoints
|
| 78 |
+
CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \
|
| 79 |
+
--save-dir checkpoints \
|
| 80 |
+
--arch berard_simul_text_iwslt \
|
| 81 |
+
--simul-type waitk \
|
| 82 |
+
--waitk-lagging 2 \
|
| 83 |
+
--optimizer adam \
|
| 84 |
+
--max-epoch 100 \
|
| 85 |
+
--lr 0.001 \
|
| 86 |
+
--clip-norm 5.0 \
|
| 87 |
+
--batch-size 128 \
|
| 88 |
+
--log-format json \
|
| 89 |
+
--log-interval 10 \
|
| 90 |
+
--criterion cross_entropy_acc \
|
| 91 |
+
--user-dir $FAIRSEQ/examples/simultaneous_translation
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## **Speech-to-text Model**
|
| 95 |
+
---
|
| 96 |
+
### Data Preparation
|
| 97 |
+
First, segment wav files.
|
| 98 |
+
```shell
|
| 99 |
+
python $FAIRSEQ/examples/simultaneous_translation/data/segment_wav.py \
|
| 100 |
+
--datapath $DATA_ROOT
|
| 101 |
+
```
|
| 102 |
+
Similar to the text-to-text model, train a SentencePiece model, but only on German
|
| 103 |
+
```Shell
|
| 104 |
+
python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \
|
| 105 |
+
--data-path $DATA_ROOT/data \
|
| 106 |
+
--vocab-size 10000 \
|
| 107 |
+
--max-frame 3000 \
|
| 108 |
+
--model-type unigram \
|
| 109 |
+
--lang $lang \
|
| 110 |
+
--out-path .
|
| 111 |
+
```
|
| 112 |
+
## Training
|
| 113 |
+
```shell
|
| 114 |
+
mkdir -p checkpoints
|
| 115 |
+
CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \
|
| 116 |
+
--save-dir checkpoints \
|
| 117 |
+
--arch berard_simul_text_iwslt \
|
| 118 |
+
--waitk-lagging 2 \
|
| 119 |
+
--waitk-stride 10 \
|
| 120 |
+
--input-feat-per-channel 40 \
|
| 121 |
+
--encoder-hidden-size 512 \
|
| 122 |
+
--output-layer-dim 128 \
|
| 123 |
+
--decoder-num-layers 3 \
|
| 124 |
+
--task speech_translation \
|
| 125 |
+
--user-dir $FAIRSEQ/examples/simultaneous_translation
|
| 126 |
+
--optimizer adam \
|
| 127 |
+
--max-epoch 100 \
|
| 128 |
+
--lr 0.001 \
|
| 129 |
+
--clip-norm 5.0 \
|
| 130 |
+
--batch-size 128 \
|
| 131 |
+
--log-format json \
|
| 132 |
+
--log-interval 10 \
|
| 133 |
+
--criterion cross_entropy_acc \
|
| 134 |
+
--user-dir $FAIRSEQ/examples/simultaneous_translation
|
| 135 |
+
```
|
| 136 |
+
|
| 137 |
+
## Evaluation
|
| 138 |
+
---
|
| 139 |
+
### Evaluation Server
|
| 140 |
+
For text translation models, the server is set up as follows, given an input file and a reference file.
|
| 141 |
+
|
| 142 |
+
``` shell
|
| 143 |
+
python ./eval/server.py \
|
| 144 |
+
--hostname localhost \
|
| 145 |
+
--port 12321 \
|
| 146 |
+
--src-file $DATA_ROOT/data/dev/txt/dev.en \
|
| 147 |
+
--ref-file $DATA_ROOT/data/dev/txt/dev.de
|
| 148 |
+
```
|
| 149 |
+
For speech translation models, the input is the data directory.
|
| 150 |
+
``` shell
|
| 151 |
+
python ./eval/server.py \
|
| 152 |
+
--hostname localhost \
|
| 153 |
+
--port 12321 \
|
| 154 |
+
--ref-file $DATA_ROOT \
|
| 155 |
+
--data-type speech
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
### Decode and Evaluate with Client
|
| 159 |
+
Once the server is set up, run client to evaluate translation quality and latency.
|
| 160 |
+
```shell
|
| 161 |
+
# TEXT
|
| 162 |
+
python $fairseq_dir/examples/simultaneous_translation/evaluate.py \
|
| 163 |
+
data-bin/mustc_en_de \
|
| 164 |
+
--user-dir $FAIRSEQ/examples/simultaneous_translation \
|
| 165 |
+
--src-spm unigram-en-10000-3000/spm.model\
|
| 166 |
+
--tgt-spm unigram-de-10000-3000/spm.model\
|
| 167 |
+
-s en -t de \
|
| 168 |
+
--path checkpoints/checkpoint_best.pt
|
| 169 |
+
|
| 170 |
+
# SPEECH
|
| 171 |
+
python $fairseq_dir/examples/simultaneous_translation/evaluate.py \
|
| 172 |
+
data-bin/mustc_en_de \
|
| 173 |
+
--user-dir $FAIRSEQ/examples/simultaneous_translation \
|
| 174 |
+
--data-type speech \
|
| 175 |
+
--tgt-spm unigram-de-10000-3000/spm.model\
|
| 176 |
+
-s en -t de \
|
| 177 |
+
--path checkpoints/checkpoint_best.pt
|
| 178 |
+
```
|
fairseq-0.10.2/examples/simultaneous_translation/docs/evaluation.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Introduction to evaluation interface
|
| 2 |
+
The simultaneous translation models from shared task participants are evaluated under a server-client protocol. The participants are requested to plug their own model API into the protocol, and to submit a Docker file.
|
| 3 |
+
|
| 4 |
+
## Server-Client Protocol
|
| 5 |
+
An server-client protocol that will be used in evaluation. For example, when a *wait-k* model (k=3) translate the English sentence "Alice and Bob are good friends" to Genman sentence "Alice und Bob sind gute Freunde." , the evaluation process is shown as following figure.
|
| 6 |
+
|
| 7 |
+
Every time the client needs to read a new state (word or speech utterance), a "GET" request is sent to the server. Whenever a new token is generated, a "SEND" request with the predicted word (untokenized) is sent to the server immediately. The server can then calculate both the latency and the BLEU score of the sentence.
|
| 8 |
+
|
| 9 |
+
### Server
|
| 10 |
+
The server code is provided and can be set up directly locally for development purpose. For example, to evaluate a text simultaneous test set,
|
| 11 |
+
|
| 12 |
+
```shell
|
| 13 |
+
|
| 14 |
+
python fairseq/examples/simultaneous_translation/eval/server.py \
|
| 15 |
+
--hostname local_host \
|
| 16 |
+
--port 1234 \
|
| 17 |
+
--src-file SRC_FILE \
|
| 18 |
+
--ref-file REF_FILE \
|
| 19 |
+
--data-type text \
|
| 20 |
+
```
|
| 21 |
+
The state that the server sends to the client has the following format
|
| 22 |
+
```json
|
| 23 |
+
{
|
| 24 |
+
'sent_id': Int,
|
| 25 |
+
'segment_id': Int,
|
| 26 |
+
'segment': String
|
| 27 |
+
}
|
| 28 |
+
```
|
| 29 |
+
|
| 30 |
+
### Client
|
| 31 |
+
The client will handle the evaluation process mentioned above. It should be out-of-box as well. The client's protocol is as following table
|
| 32 |
+
|
| 33 |
+
|Action|Content|
|
| 34 |
+
|:---:|:---:|
|
| 35 |
+
|Request new word / utterence| ```{key: "Get", value: None}```|
|
| 36 |
+
|Predict word "W"| ```{key: "SEND", value: "W"}```|
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
The core of the client module is the agent, which needs to be modified to different models accordingly. The abstract class of agent is as follow, the evaluation process happens in the `decode()` function.
|
| 41 |
+
```python
|
| 42 |
+
class Agent(object):
|
| 43 |
+
"an agent needs to follow this pattern"
|
| 44 |
+
def __init__(self, *args, **kwargs):
|
| 45 |
+
...
|
| 46 |
+
|
| 47 |
+
def init_states(self):
|
| 48 |
+
# Initializing states
|
| 49 |
+
...
|
| 50 |
+
|
| 51 |
+
def update_states(self, states, new_state):
|
| 52 |
+
# Update states with given new state from server
|
| 53 |
+
# TODO (describe the states)
|
| 54 |
+
...
|
| 55 |
+
|
| 56 |
+
def finish_eval(self, states, new_state):
|
| 57 |
+
# Check if evaluation is finished
|
| 58 |
+
...
|
| 59 |
+
|
| 60 |
+
def policy(self, state: list) -> dict:
|
| 61 |
+
# Provide a action given current states
|
| 62 |
+
# The action can only be either
|
| 63 |
+
# {key: "GET", value: NONE}
|
| 64 |
+
# or
|
| 65 |
+
# {key: "SEND", value: W}
|
| 66 |
+
...
|
| 67 |
+
|
| 68 |
+
def reset(self):
|
| 69 |
+
# Reset agent
|
| 70 |
+
...
|
| 71 |
+
|
| 72 |
+
def decode(self, session):
|
| 73 |
+
|
| 74 |
+
states = self.init_states()
|
| 75 |
+
self.reset()
|
| 76 |
+
|
| 77 |
+
# Evaluataion protocol happens here
|
| 78 |
+
while True:
|
| 79 |
+
# Get action from the current states according to self.policy()
|
| 80 |
+
action = self.policy(states)
|
| 81 |
+
|
| 82 |
+
if action['key'] == GET:
|
| 83 |
+
# Read a new state from server
|
| 84 |
+
new_state = session.get_src()
|
| 85 |
+
states = self.update_states(states, new_state)
|
| 86 |
+
|
| 87 |
+
if self.finish_eval(states, new_state):
|
| 88 |
+
# End of document
|
| 89 |
+
break
|
| 90 |
+
|
| 91 |
+
elif action['key'] == SEND:
|
| 92 |
+
# Send a new prediction to server
|
| 93 |
+
session.send_hypo(action['value'])
|
| 94 |
+
|
| 95 |
+
# Clean the history, wait for next sentence
|
| 96 |
+
if action['value'] == DEFAULT_EOS:
|
| 97 |
+
states = self.init_states()
|
| 98 |
+
self.reset()
|
| 99 |
+
else:
|
| 100 |
+
raise NotImplementedError
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
```
|
| 104 |
+
Here an implementation of agent of text [*wait-k* model](somelink). Notice that the tokenization is not considered.
|
| 105 |
+
|
| 106 |
+
## Quality
|
| 107 |
+
The quality is measured by detokenized BLEU, so make sure that the predicted words sent to the server are detokenized. An implementation can be found [here](some link)
|
| 108 |
+
|
| 109 |
+
## Latency
|
| 110 |
+
The latency metrics are
|
| 111 |
+
* Average Proportion
|
| 112 |
+
* Average Lagging
|
| 113 |
+
* Differentiable Average Lagging
|
| 114 |
+
Again, they will also be evaluated on detokenized text.
|
| 115 |
+
|
fairseq-0.10.2/examples/simultaneous_translation/eval/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
fairseq-0.10.2/examples/simultaneous_translation/eval/agents/word_splitter.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class SubwordSplitter(object):
    """Abstract interface for the subword tokenizers used by the agents."""

    def process_line(self, string):
        """Tokenize a full input line; must be implemented by subclasses."""
        raise NotImplementedError

    def split(self, string):
        """Split a string into subword tokens; must be implemented by subclasses."""
        raise NotImplementedError
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class NoneWordSplitter(object):
    """Pass-through splitter: every incoming string is one finished word."""

    def __init__(self, model):
        # No model is needed; the argument exists only for interface parity.
        pass

    def split(self, string):
        """Return the string wrapped in a single-element list (no splitting)."""
        return [string]

    def process_line(self, string):
        """Same as :meth:`split`; a whole line counts as one token."""
        return [string]

    def finished_word(self, string):
        """Every string is already a complete word."""
        return True

    def merge(self, list_of_string):
        """Concatenate pieces back together without separators."""
        return "".join(list_of_string)

    def last_full_word_step(self, tokens, step):
        """All tokens are full words, so the last full-word step is the count."""
        return len(tokens)

    def end_idx_last_full_word(self, tokens):
        """All tokens are full words, so the boundary is the end of the list."""
        return len(tokens)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class BPEWordSplitter(object):
    """Subword splitter backed by subword-nmt BPE codes.

    BPE marks every non-final piece of a word with an "@@" suffix
    (e.g. "hello" -> "he@@ llo").
    """

    def __init__(self, model_path):
        """Load the BPE merge codes from *model_path*."""
        super().__init__()
        from subword_nmt.apply_bpe import BPE

        with open(model_path) as f:
            self.model = BPE(f)

    def split(self, string):
        """Apply BPE to *string* and return the list of subword tokens."""
        return self.model.process_line(string).split()

    def end_idx_last_full_word(self, tokens):
        """Return the index one past the last *complete* word in *tokens*.

        A token starts a new word iff the *previous* token does not carry the
        "@@" continuation suffix. (The original tested the current token
        instead, mislabeling word-internal pieces as word starts.) The final
        word is always treated as potentially incomplete, so 0 is returned
        when no earlier word boundary exists.
        """
        # Begin-of-word indices: 0, plus every position whose predecessor
        # does not end with the "@@" continuation marker.
        bow_indices = [0] + [
            i + 1 for i, t in enumerate(tokens[:-1]) if t[-2:] != "@@"
        ]

        if len(bow_indices) < 2:
            return 0
        else:
            return bow_indices[-1]

    def merge(self, list_of_string):
        """Detokenize BPE pieces back into words.

        Removing every "@@ " pair after joining glues continuation pieces to
        their word; the original stripped "@@" but kept the space, leaving a
        stray space inside each merged word ("he@@ llo" -> "he llo").
        """
        return " ".join(list_of_string).replace("@@ ", "")
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class SentencePieceModelWordSplitter(object):
    """Subword splitter backed by a trained SentencePiece model.

    SentencePiece marks the *beginning* of every word with U+2581 ("▁").
    """

    def __init__(self, model_path):
        super().__init__()
        import sentencepiece as spm

        self.model = spm.SentencePieceProcessor()
        self.model.Load(model_path)

    def split(self, string):
        """Encode *string* into SentencePiece tokens."""
        return self.model.EncodeAsPieces(string)

    def end_idx_last_full_word(self, tokens):
        """Return the index one past the last complete word in *tokens*."""
        # Indices of tokens that begin a word (carry the "▁" prefix).
        bow_indices = [
            idx for idx, piece in enumerate(tokens) if piece[0] == "\u2581"
        ]
        return bow_indices[-1] if len(bow_indices) >= 2 else 0

    def merge(self, list_of_string):
        """Decode the pieces back into plain text."""
        return self.model.DecodePieces(list_of_string)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Registry mapping a splitter-type name (as selected on the command line) to
# the class implementing it; None selects the pass-through splitter.
SPLITTER_DICT = {
    None: NoneWordSplitter,
    "BPE": BPEWordSplitter,
    "SentencePieceModel": SentencePieceModelWordSplitter,
}
|
fairseq-0.10.2/examples/simultaneous_translation/eval/client.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
import requests
|
| 9 |
+
from scorers import build_scorer
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SimulSTEvaluationService(object):
    """HTTP client for a remote simultaneous-translation evaluation server."""

    DEFAULT_HOSTNAME = "localhost"
    DEFAULT_PORT = 12321

    def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
        self.hostname = hostname
        self.port = port
        self.base_url = f"http://{self.hostname}:{self.port}"

    def __enter__(self):
        # Return the started session so `with ... as session:` binds the
        # service (the original returned None, leaving the as-target unusable).
        return self.new_session()

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass

    def new_session(self):
        """Ask the server to start a new evaluation session (best effort)."""
        url = f"{self.base_url}"

        try:
            _ = requests.post(url)
        except Exception as e:
            print(f"Failed to start an evaluation session: {e}")

        print("Evaluation session started.")
        return self

    def get_scores(self):
        """Fetch and print the final scores from the server (best effort)."""
        url = f"{self.base_url}/result"
        try:
            r = requests.get(url)
            print("Scores: {}".format(r.json()))
            print("Evaluation session finished.")
        except Exception as e:
            print(f"Failed to end an evaluation session: {e}")

    def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
        """Request (a segment of) source sentence *sent_id* from the server.

        Raises whatever `requests` raised if the HTTP call fails.
        """
        url = f"{self.base_url}/src"
        params = {"sent_id": sent_id}
        if extra_params is not None:
            params.update(extra_params)
        try:
            r = requests.get(url, params=params)
        except Exception as e:
            print(f"Failed to request a source segment: {e}")
            # Re-raise: the original fell through to `r.json()` with `r`
            # unbound, masking the real error with a NameError.
            raise
        return r.json()

    def send_hypo(self, sent_id: int, hypo: str) -> None:
        """Send a (partial) hypothesis for *sent_id* to the server (best effort)."""
        url = f"{self.base_url}/hypo"
        params = {"sent_id": sent_id}

        try:
            requests.put(url, params=params, data=hypo.encode("utf-8"))
        except Exception as e:
            print(f"Failed to send a translated segment: {e}")

    def corpus_info(self):
        """Fetch corpus metadata from the server.

        Raises whatever `requests` raised if the HTTP call fails.
        """
        url = f"{self.base_url}"
        try:
            r = requests.get(url)
        except Exception as e:
            print(f"Failed to request corpus information: {e}")
            # Re-raise instead of hitting an unbound `r` below.
            raise

        return r.json()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class SimulSTLocalEvaluationService(object):
    """In-process evaluation service exposing the same interface as the HTTP client."""

    def __init__(self, args):
        self.scorer = build_scorer(args)

    def get_scores(self):
        """Return the final scores computed by the underlying scorer."""
        return self.scorer.score()

    def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
        """Fetch (a segment of) source sentence *sent_id* from the scorer."""
        segment_size = (extra_params or {}).get("segment_size", None)
        return self.scorer.send_src(int(sent_id), segment_size)

    def send_hypo(self, sent_id: int, hypo: str) -> None:
        """Forward a whitespace-tokenized hypothesis to the scorer."""
        self.scorer.recv_hyp(sent_id, hypo.strip().split())

    def corpus_info(self):
        """Return corpus metadata from the scorer."""
        return self.scorer.get_info()
|
fairseq-0.10.2/examples/simultaneous_translation/eval/eval_latency.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
from examples.simultaneous_translation.utils.latency import LatencyInference
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Names of the latency metrics produced by LatencyInference and averaged by
# LatencyScorer / the CLI entry point below.
LATENCY_METRICS = [
    "differentiable_average_lagging",
    "average_lagging",
    "average_proportion",
]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class LatencyScorer:
    """Compute corpus-level latency metrics from per-sentence delay records."""

    def __init__(self, start_from_zero=True):
        # One entry per sentence: dict mapping metric name -> 1x1 tensor.
        self.recorder = []
        self.scores = {}
        self.scorer = LatencyInference()
        self.start_from_zero = start_from_zero

    def update_reorder(self, list_of_dict):
        """Re-score every record in *list_of_dict*, replacing earlier results.

        Each record needs integer ``"delays"`` and an integer ``"src_len"``;
        delays are shifted down by one when they are 1-based.
        """
        self.recorder = []
        offset = int(not self.start_from_zero)
        for info in list_of_dict:
            delays = torch.LongTensor(
                [int(x) - offset for x in info["delays"]]
            ).unsqueeze(0)
            src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
            self.recorder.append(self.scorer(delays, src_len))

    def cal_latency(self):
        """Average each latency metric over all recorded sentences."""
        self.scores = {
            metric: sum(rec[metric][0, 0].item() for rec in self.recorder)
            / len(self.recorder)
            for metric in LATENCY_METRICS
        }
        return self.scores

    @classmethod
    def score(cls, list_of_dict, start_from_zero=True):
        """One-shot convenience: score *list_of_dict* and return the averages."""
        instance = cls(start_from_zero)
        instance.update_reorder(list_of_dict)
        instance.cal_latency()
        return instance.scores
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True)
    parser.add_argument("--start-from-zero", action="store_true")
    args = parser.parse_args()

    # Each input line is a JSON record with "delays" and "src_len".
    with open(args.input, "r") as f:
        records = [json.loads(line) for line in f]

    # Reuse LatencyScorer.score instead of re-implementing the same
    # delay-shift / scoring / averaging loop inline (the original duplicated
    # that logic from the class above).
    average_results = LatencyScorer.score(
        records, start_from_zero=args.start_from_zero
    )

    for metric in LATENCY_METRICS:
        print(f"{metric}: {average_results[metric]}")
|
fairseq-0.10.2/examples/simultaneous_translation/eval/evaluate.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
|
| 8 |
+
from agents import build_agent
|
| 9 |
+
from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
|
| 10 |
+
from fairseq.registry import REGISTRIES
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
DEFAULT_HOSTNAME = "localhost"
|
| 14 |
+
DEFAULT_PORT = 12321
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_args():
    """Parse evaluation CLI options, including options contributed by
    fairseq registry classes (e.g. the selected agent/scorer types).

    Parsing is two-pass: a first ``parse_known_args`` discovers which
    registry entries were chosen, those classes then add their own options
    to the parser, and the final ``parse_args`` validates everything.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
    )
    parser.add_argument(
        "--port", type=int, default=DEFAULT_PORT, help="server port number"
    )
    parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
    parser.add_argument("--scorer-type", default="text", help="Scorer type")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index of the sentence to evaluate",
    )
    parser.add_argument(
        "--end-idx",
        type=int,
        # Default is float("inf") so "no end" works without a sentinel; an
        # explicit --end-idx is still converted to int by type=int.
        default=float("inf"),
        help="End index of the sentence to evaluate",
    )
    parser.add_argument(
        "--scores", action="store_true", help="Request scores from server"
    )
    parser.add_argument("--reset-server", action="store_true", help="Reset the server")
    parser.add_argument(
        "--num-threads", type=int, default=10, help="Number of threads used by agent"
    )
    parser.add_argument(
        "--local", action="store_true", default=False, help="Local evaluation"
    )

    # First pass: ignore options that belong to registry classes.
    args, _ = parser.parse_known_args()

    # Let every selected registry class register its own CLI arguments
    # before the final, strict parse.
    for registry_name, REGISTRY in REGISTRIES.items():
        choice = getattr(args, registry_name, None)
        if choice is not None:
            cls = REGISTRY["registry"][choice]
            if hasattr(cls, "add_args"):
                cls.add_args(parser)
    args = parser.parse_args()

    return args
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
if __name__ == "__main__":
    args = get_args()

    # Local evaluation runs the scorer in-process; otherwise talk to a server.
    if args.local:
        session = SimulSTLocalEvaluationService(args)
    else:
        session = SimulSTEvaluationService(args.hostname, args.port)

    if args.reset_server:
        # NOTE(review): the local service defines no new_session(); this
        # flag assumes a remote server -- confirm before combining with --local.
        session.new_session()

    if args.agent_type is not None:
        agent = build_agent(args)
        agent.decode(session, args.start_idx, args.end_idx, args.num_threads)

    if args.scores:
        # Fetch the scores once and print the result (the original called
        # get_scores() twice, issuing two requests against the server).
        print(session.get_scores())
|
fairseq-0.10.2/examples/simultaneous_translation/eval/server.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import sys
|
| 8 |
+
|
| 9 |
+
from scorers import build_scorer
|
| 10 |
+
from tornado import ioloop, web
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
DEFAULT_HOSTNAME = "localhost"
|
| 14 |
+
DEFAULT_PORT = 12321
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class ScorerHandler(web.RequestHandler):
    """Base handler holding a reference to the shared scorer instance."""

    def initialize(self, scorer):
        # Tornado calls initialize() with the dict passed at route creation.
        self.scorer = scorer
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class EvalSessionHandler(ScorerHandler):
    """POST / resets the scorer (new session); GET / returns corpus info as JSON."""

    def post(self):
        # Start a fresh evaluation session.
        self.scorer.reset()

    def get(self):
        r = json.dumps(self.scorer.get_info())
        self.write(r)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class ResultHandler(ScorerHandler):
    """GET /result returns the final evaluation scores as JSON."""

    def get(self):
        r = json.dumps(self.scorer.score())
        self.write(r)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class SourceHandler(ScorerHandler):
    """GET /src returns (a segment of) the requested source sentence as JSON."""

    def get(self):
        sent_id = int(self.get_argument("sent_id"))

        # Optional segment size; an absent or empty value means "send all".
        segment_size = None
        if "segment_size" in self.request.arguments:
            raw_value = self.get_argument("segment_size")
            if raw_value:
                segment_size = int(raw_value)

        self.write(json.dumps(self.scorer.send_src(int(sent_id), segment_size)))
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class HypothesisHandler(ScorerHandler):
    """PUT /hypo receives a (partial) hypothesis for a sentence.

    The request body is a UTF-8 string of whitespace-separated tokens.
    """

    def put(self):
        sent_id = int(self.get_argument("sent_id"))
        list_of_tokens = self.request.body.decode("utf-8").strip().split()
        self.scorer.recv_hyp(sent_id, list_of_tokens)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def add_args():
    """Parse server CLI options.

    Unknown options are ignored (``parse_known_args``) so that scorer
    options parsed elsewhere can share the same command line.
    """
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
                        help='Server hostname')
    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
                        help='Server port number')
    # Fix: __main__ passes args.debug to start_server(), but this flag was
    # never defined, crashing the server at startup with AttributeError.
    parser.add_argument('--debug', action='store_true',
                        help='Run the tornado application in debug mode')

    args, _ = parser.parse_known_args()
    # fmt: on
    return args
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
    """Wire *scorer* into a tornado application and serve until interrupted."""
    routes = [
        (r"/result", ResultHandler, dict(scorer=scorer)),
        (r"/src", SourceHandler, dict(scorer=scorer)),
        (r"/hypo", HypothesisHandler, dict(scorer=scorer)),
        (r"/", EvalSessionHandler, dict(scorer=scorer)),
    ]
    app = web.Application(routes, debug=debug)
    # Allow large request bodies (long documents / hypotheses).
    app.listen(port, max_buffer_size=1024 ** 3)
    sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
    ioloop.IOLoop.current().start()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
    args = add_args()
    scorer = build_scorer(args)
    # getattr guards against an add_args() that defines no --debug flag
    # (the original crashed here with AttributeError).
    start_server(scorer, args.hostname, args.port, getattr(args, "debug", False))
|
fairseq-0.10.2/examples/simultaneous_translation/models/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import importlib
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Auto-import every model module in this package so that their
# @register_model / @register_model_architecture decorators run.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        # Strip only the trailing ".py"; the original used file.find(".py"),
        # which truncates at the *first* occurrence and mangles any module
        # name containing ".py" in the middle.
        model_name = file[: -len(".py")]
        importlib.import_module(
            "examples.simultaneous_translation.models." + model_name
        )
|
fairseq-0.10.2/examples/simultaneous_translation/models/transformer_monotonic_attention.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
|
| 10 |
+
TransformerMonotonicDecoderLayer,
|
| 11 |
+
TransformerMonotonicEncoderLayer,
|
| 12 |
+
)
|
| 13 |
+
from fairseq.models import register_model, register_model_architecture
|
| 14 |
+
from fairseq.models.transformer import (
|
| 15 |
+
TransformerDecoder,
|
| 16 |
+
TransformerEncoder,
|
| 17 |
+
TransformerModel,
|
| 18 |
+
base_architecture,
|
| 19 |
+
transformer_iwslt_de_en,
|
| 20 |
+
transformer_vaswani_wmt_en_de_big,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
DEFAULT_MAX_SOURCE_POSITIONS = 1024
|
| 25 |
+
DEFAULT_MAX_TARGET_POSITIONS = 1024
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
    """Standard transformer whose encoder is built from the monotonic
    encoder layers (supporting unidirectional, left-to-right encoding)."""

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        # Swap in the monotonic encoder; decoder stays the stock one.
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
    """Transformer with monotonic attention for simultaneous translation.

    Besides the usual train-time interface, it exposes helpers used by the
    online agents: converting the agent's ``states`` dict into tensors,
    predicting the next token from cached decoder features, and deciding
    whether to READ more source or WRITE a target token.
    """

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

    def _indices_from_states(self, states):
        # Build (src_indices, src_lengths, tgt_indices) from the agent state.
        # assumes states["indices"]["src"/"tgt"] hold either python lists of
        # token ids or already-batched tensors -- TODO confirm against agent
        if type(states["indices"]["src"]) == list:
            # Match the tensor type to the model's device.
            if next(self.parameters()).is_cuda:
                tensor = torch.cuda.LongTensor
            else:
                tensor = torch.LongTensor

            src_indices = tensor(
                [states["indices"]["src"][: 1 + states["steps"]["src"]]]
            )

            # Target is prefixed with EOS as the decoder's start symbol.
            tgt_indices = tensor(
                [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
            )
        else:
            src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
            tgt_indices = states["indices"]["tgt"]

        # src_lengths is not computed here (None).
        return src_indices, None, tgt_indices

    def predict_from_states(self, states):
        """Greedily predict the next target token from cached decoder features.

        Returns the token string and its vocabulary index.
        """
        decoder_states = self.decoder.output_layer(states["decoder_features"])
        lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)

        index = lprobs.argmax(dim=-1)

        token = self.decoder.dictionary.string(index)

        return token, index[0, 0].item()

    def decision_from_states(self, states):
        """
        Take the states dictionary as input and give the agent a decision of
        whether to read a token from the server (0 = READ, 1 = WRITE).
        The decoder features are also computed here and cached in
        ``states["decoder_features"]`` so that a following
        :meth:`predict_from_states` call does not recompute everything.
        """

        self.eval()

        if len(states["tokens"]["src"]) == 0:
            # Nothing read yet: must READ first.
            return 0

        src_indices, src_lengths, tgt_indices = self._indices_from_states(states)

        # Update encoder states if needed
        if (
            "encoder_states" not in states
            or states["encoder_states"][0].size(1) <= states["steps"]["src"]
        ):
            encoder_out_dict = self.encoder(src_indices, src_lengths)
            states["encoder_states"] = encoder_out_dict
        else:
            encoder_out_dict = states["encoder_states"]

        # online means we still need tokens to feed the model
        states["model_states"]["online"] = not (
            states["finish_read"]
            and len(states["tokens"]["src"]) == states["steps"]["src"]
        )

        states["model_states"]["steps"] = states["steps"]

        x, outputs = self.decoder.forward(
            prev_output_tokens=tgt_indices,
            encoder_out=encoder_out_dict,
            incremental_state=states["model_states"],
            features_only=True,
        )

        # Cache features for a later predict_from_states().
        states["decoder_features"] = x

        return outputs["action"]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class TransformerMonotonicEncoder(TransformerEncoder):
    """Transformer encoder built from TransformerMonotonicEncoderLayer layers."""

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        self.dictionary = dictionary
        # Replace the stock encoder layers with monotonic ones.
        monotonic_layers = [
            TransformerMonotonicEncoderLayer(args)
            for _ in range(args.encoder_layers)
        ]
        self.layers = nn.ModuleList(monotonic_layers)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class TransformerMonotonicDecoder(TransformerDecoder):
|
| 132 |
+
"""
|
| 133 |
+
Transformer decoder consisting of *args.decoder_layers* layers. Each layer
|
| 134 |
+
is a :class:`TransformerDecoderLayer`.
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
args (argparse.Namespace): parsed command-line arguments
|
| 138 |
+
dictionary (~fairseq.data.Dictionary): decoding dictionary
|
| 139 |
+
embed_tokens (torch.nn.Embedding): output embedding
|
| 140 |
+
no_encoder_attn (bool, optional): whether to attend to encoder outputs
|
| 141 |
+
(default: False).
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
|
| 145 |
+
super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)
|
| 146 |
+
|
| 147 |
+
self.dictionary = dictionary
|
| 148 |
+
self.layers = nn.ModuleList([])
|
| 149 |
+
self.layers.extend(
|
| 150 |
+
[
|
| 151 |
+
TransformerMonotonicDecoderLayer(args, no_encoder_attn)
|
| 152 |
+
for _ in range(args.decoder_layers)
|
| 153 |
+
]
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
def pre_attention(
|
| 157 |
+
self, prev_output_tokens, encoder_out_dict, incremental_state=None
|
| 158 |
+
):
|
| 159 |
+
positions = (
|
| 160 |
+
self.embed_positions(
|
| 161 |
+
prev_output_tokens,
|
| 162 |
+
incremental_state=incremental_state,
|
| 163 |
+
)
|
| 164 |
+
if self.embed_positions is not None
|
| 165 |
+
else None
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
if incremental_state is not None:
|
| 169 |
+
prev_output_tokens = prev_output_tokens[:, -1:]
|
| 170 |
+
if positions is not None:
|
| 171 |
+
positions = positions[:, -1:]
|
| 172 |
+
|
| 173 |
+
# embed tokens and positions
|
| 174 |
+
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
|
| 175 |
+
|
| 176 |
+
if self.project_in_dim is not None:
|
| 177 |
+
x = self.project_in_dim(x)
|
| 178 |
+
|
| 179 |
+
if positions is not None:
|
| 180 |
+
x += positions
|
| 181 |
+
x = self.dropout_module(x)
|
| 182 |
+
|
| 183 |
+
# B x T x C -> T x B x C
|
| 184 |
+
x = x.transpose(0, 1)
|
| 185 |
+
|
| 186 |
+
encoder_out = encoder_out_dict.encoder_out
|
| 187 |
+
encoder_padding_mask = encoder_out_dict.encoder_padding_mask
|
| 188 |
+
|
| 189 |
+
return x, encoder_out, encoder_padding_mask
|
| 190 |
+
|
| 191 |
+
def post_attention(self, x):
|
| 192 |
+
if self.layer_norm:
|
| 193 |
+
x = self.layer_norm(x)
|
| 194 |
+
|
| 195 |
+
# T x B x C -> B x T x C
|
| 196 |
+
x = x.transpose(0, 1)
|
| 197 |
+
|
| 198 |
+
if self.project_out_dim is not None:
|
| 199 |
+
x = self.project_out_dim(x)
|
| 200 |
+
|
| 201 |
+
return x
|
| 202 |
+
|
| 203 |
+
def extract_features(
|
| 204 |
+
self, prev_output_tokens, encoder_out, incremental_state=None, **unused
|
| 205 |
+
):
|
| 206 |
+
"""
|
| 207 |
+
Similar to *forward* but only return features.
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
tuple:
|
| 211 |
+
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
|
| 212 |
+
- a dictionary with any model-specific outputs
|
| 213 |
+
"""
|
| 214 |
+
# incremental_state = None
|
| 215 |
+
(x, encoder_outs, encoder_padding_mask) = self.pre_attention(
|
| 216 |
+
prev_output_tokens, encoder_out, incremental_state
|
| 217 |
+
)
|
| 218 |
+
attn = None
|
| 219 |
+
inner_states = [x]
|
| 220 |
+
attn_list = []
|
| 221 |
+
step_list = []
|
| 222 |
+
|
| 223 |
+
for i, layer in enumerate(self.layers):
|
| 224 |
+
|
| 225 |
+
x, attn, _ = layer(
|
| 226 |
+
x=x,
|
| 227 |
+
encoder_out=encoder_outs,
|
| 228 |
+
encoder_padding_mask=encoder_padding_mask,
|
| 229 |
+
incremental_state=incremental_state,
|
| 230 |
+
self_attn_mask=self.buffered_future_mask(x)
|
| 231 |
+
if incremental_state is None
|
| 232 |
+
else None,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
inner_states.append(x)
|
| 236 |
+
attn_list.append(attn)
|
| 237 |
+
|
| 238 |
+
if incremental_state is not None:
|
| 239 |
+
curr_steps = layer.get_steps(incremental_state)
|
| 240 |
+
step_list.append(curr_steps)
|
| 241 |
+
|
| 242 |
+
if incremental_state.get("online", False):
|
| 243 |
+
p_choose = (
|
| 244 |
+
attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)
|
| 248 |
+
|
| 249 |
+
if (new_steps >= incremental_state["steps"]["src"]).any():
|
| 250 |
+
# We need to prune the last self_attn saved_state
|
| 251 |
+
# if model decide not to read
|
| 252 |
+
# otherwise there will be duplicated saved_state
|
| 253 |
+
for j in range(i + 1):
|
| 254 |
+
self.layers[j].prune_incremental_state(incremental_state)
|
| 255 |
+
|
| 256 |
+
return x, {"action": 0}
|
| 257 |
+
|
| 258 |
+
if incremental_state is not None and not incremental_state.get("online", False):
|
| 259 |
+
# Here is for fast evaluation
|
| 260 |
+
fastest_step = (
|
| 261 |
+
torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
if "fastest_step" in incremental_state:
|
| 265 |
+
incremental_state["fastest_step"] = torch.cat(
|
| 266 |
+
[incremental_state["fastest_step"], fastest_step], dim=1
|
| 267 |
+
)
|
| 268 |
+
else:
|
| 269 |
+
incremental_state["fastest_step"] = fastest_step
|
| 270 |
+
|
| 271 |
+
x = self.post_attention(x)
|
| 272 |
+
|
| 273 |
+
return x, {
|
| 274 |
+
"action": 1,
|
| 275 |
+
"attn_list": attn_list,
|
| 276 |
+
"step_list": step_list,
|
| 277 |
+
"encoder_out": encoder_out,
|
| 278 |
+
"encoder_padding_mask": encoder_padding_mask,
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered decoding state for beam search / batch reordering.

        Beyond the base class's reordering of per-layer buffers, this decoder
        additionally caches a ``"fastest_step"`` tensor (the furthest source
        step read by any head, accumulated during fast evaluation); its batch
        dimension (dim 0) must be permuted with the same ``new_order``.
        """
        super().reorder_incremental_state(incremental_state, new_order)
        if "fastest_step" in incremental_state:
            incremental_state["fastest_step"] = incremental_state[
                "fastest_step"
            ].index_select(0, new_order)
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_rchitecture(args):
    """Default hyperparameters for the monotonic transformer.

    NOTE(review): the function name is missing an 'a' ("rchitecture"), but it
    is referenced by other architecture functions in this file, so renaming
    would have to update all call sites at once.
    """
    base_architecture(args)
    # Monotonic-specific default: bidirectional encoder unless requested.
    args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
    """IWSLT'14 De-En sized monotonic transformer: base IWSLT hyperparameters
    plus the monotonic defaults."""
    transformer_iwslt_de_en(args)
    base_monotonic_rchitecture(args)
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
    """Big monotonic transformer with WMT'14 En-De (Vaswani et al.) sizes.

    NOTE(review): unlike the IWSLT variant above, this does not also apply
    base_monotonic_rchitecture — presumably relying on getattr defaults
    elsewhere; confirm that is intentional.
    """
    transformer_vaswani_wmt_en_de_big(args)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    """Big monotonic transformer with WMT'14 En-Fr (Vaswani et al.) sizes.

    Bug fix: the original body called *itself*
    (``transformer_monotonic_vaswani_wmt_en_fr_big``), so selecting this
    architecture recursed until RecursionError. Delegate to the en-de big
    monotonic architecture defined just above instead, which applies the
    shared Vaswani big-transformer hyperparameters.
    """
    transformer_monotonic_vaswani_wmt_en_de_big(args)
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
@register_model_architecture(
    "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
    """IWSLT'14 De-En hyperparameters for the unidirectional-encoder model."""
    transformer_iwslt_de_en(args)
|
fairseq-0.10.2/examples/simultaneous_translation/modules/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import importlib
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
from fairseq import registry
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# Set up the registry keyed by the --simul-type CLI flag. This exposes the
# builder function, the registration decorator used by the attention modules
# in this package, and the registry mapping itself; the fourth item returned
# by setup_registry (the choices helper) is unused here.
(
    build_monotonic_attention,
    register_monotonic_attention,
    MONOTONIC_ATTENTION_REGISTRY,
    _,
) = registry.setup_registry("--simul-type")

# Import every sibling Python module (skipping private/dunder files) so that
# their @register_monotonic_attention decorators execute at import time and
# populate MONOTONIC_ATTENTION_REGISTRY.
for file in os.listdir(os.path.dirname(__file__)):
    if not file.endswith(".py"):
        continue
    if file.startswith("_"):
        continue
    model_name = file[: file.find(".py")]
    importlib.import_module(
        "examples.simultaneous_translation.modules." + model_name
    )
|
fairseq-0.10.2/examples/simultaneous_translation/modules/monotonic_multihead_attention.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Facebook, Inc. and its affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the MIT license found in the
|
| 4 |
+
# LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from examples.simultaneous_translation.utils.functions import (
|
| 12 |
+
exclusive_cumprod,
|
| 13 |
+
lengths_to_mask,
|
| 14 |
+
)
|
| 15 |
+
from fairseq import utils
|
| 16 |
+
from fairseq.incremental_decoding_utils import with_incremental_state
|
| 17 |
+
from fairseq.modules import MultiheadAttention
|
| 18 |
+
from fairseq.utils import convert_padding_direction
|
| 19 |
+
|
| 20 |
+
from . import register_monotonic_attention
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@with_incremental_state
class MonotonicAttention(nn.Module):
    """
    Abstract class of monotonic attentions.

    Subclasses must provide ``p_choose`` (stepwise read/write probabilities),
    ``input_projections``, ``v_proj_output`` and an ``expected_attention``
    used by ``forward``.
    """

    def __init__(self, args):
        # NOTE(review): nn.Module.__init__ is not called here; concrete
        # subclasses are expected to initialize the Module machinery first
        # (e.g. via MultiheadAttention.__init__) and then call this.
        self.eps = args.attention_eps  # numerical floor in expected alignment
        self.mass_preservation = args.mass_preservation

        # Gaussian noise added to energies during training to encourage
        # discreteness of p_choose.
        self.noise_mean = args.noise_mean
        self.noise_var = args.noise_var

        self.energy_bias_init = args.energy_bias_init
        # Learnable scalar bias on the monotonic energy when --energy-bias is
        # set; otherwise the plain int 0 (broadcasts in the addition below).
        self.energy_bias = (
            nn.Parameter(self.energy_bias_init * torch.ones([1]))
            if args.energy_bias is True
            else 0
        )

    @staticmethod
    def add_args(parser):
        """Add monotonic-attention arguments to the argument parser."""
        # fmt: off
        parser.add_argument('--no-mass-preservation', action="store_false", dest="mass_preservation",
                            help='Do not stay on the last token when decoding')
        parser.add_argument('--mass-preservation', action="store_true", dest="mass_preservation",
                            help='Stay on the last token when decoding')
        parser.set_defaults(mass_preservation=True)

        parser.add_argument('--noise-var', type=float, default=1.0,
                            help='Variance of discretness noise')
        parser.add_argument('--noise-mean', type=float, default=0.0,
                            help='Mean of discretness noise')
        parser.add_argument('--energy-bias', action="store_true", default=False,
                            help='Bias for energy')
        parser.add_argument('--energy-bias-init', type=float, default=-2.0,
                            help='Initial value of the bias for energy')
        parser.add_argument('--attention-eps', type=float, default=1e-6,
                            help='Epsilon when calculating expected attention')
        # fmt: on

    def p_choose(self, *args):
        # Stepwise read/write probabilities; implemented by subclasses.
        raise NotImplementedError

    def input_projections(self, *args):
        # Q/K/V projections for a named attention head set; subclass hook.
        raise NotImplementedError

    def attn_energy(self, q_proj, k_proj, key_padding_mask=None):
        """
        Calculating monotonic energies

        ============================================================
        Expected input size
        q_proj: bsz * num_heads, tgt_len, self.head_dim
        k_proj: bsz * num_heads, src_len, self.head_dim
        key_padding_mask: bsz, src_len
        attn_mask: tgt_len, src_len

        Returns energies of shape (bsz, num_heads, tgt_len, src_len); padded
        source positions are filled with -inf so sigmoid gives them zero.
        """
        bsz, tgt_len, embed_dim = q_proj.size()
        bsz = bsz // self.num_heads  # first dim arrives as bsz * num_heads
        src_len = k_proj.size(1)

        attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias

        attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)

        if key_padding_mask is not None:
            attn_energy = attn_energy.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
                float("-inf"),
            )

        return attn_energy

    def expected_alignment_train(self, p_choose, key_padding_mask):
        """
        Calculating expected alignment for MMA
        Mask is not need because p_choose will be 0 if masked

        q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j}
        a_ij = p_ij q_ij

        parellel solution:
        ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi))

        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        """

        # p_choose: bsz * num_heads, tgt_len, src_len
        bsz_num_heads, tgt_len, src_len = p_choose.size()

        # cumprod_1mp : bsz * num_heads, tgt_len, src_len
        cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
        # Clamped copy used only as a divisor, to avoid division by ~0.
        cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)

        # Base case: before the first target step, all alignment mass sits on
        # the first source position.
        init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
        init_attention[:, :, 0] = 1.0

        previous_attn = [init_attention]

        # Sequential over target steps; each alpha_i depends on alpha_{i-1}.
        for i in range(tgt_len):
            # p_choose: bsz * num_heads, tgt_len, src_len
            # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
            # previous_attn[i]: bsz * num_heads, 1, src_len
            # alpha_i: bsz * num_heads, src_len
            alpha_i = (
                p_choose[:, i]
                * cumprod_1mp[:, i]
                * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
            ).clamp(0, 1.0)
            previous_attn.append(alpha_i.unsqueeze(1))

        # alpha: bsz * num_heads, tgt_len, src_len
        alpha = torch.cat(previous_attn[1:], dim=1)

        if self.mass_preservation:
            # Last token has the residual probabilities
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)

        assert not torch.isnan(alpha).any(), "NaN detected in alpha."

        return alpha

    def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state):
        """
        Calculating monotonic alignment for MMA during inference time

        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        key_padding_mask: bsz * src_len
        incremental_state: dict
        """
        # p_choose: bsz * self.num_heads, src_len
        bsz_num_heads, tgt_len, src_len = p_choose.size()
        # One token at a time
        assert tgt_len == 1
        p_choose = p_choose[:, 0, :]

        monotonic_cache = self._get_monotonic_buffer(incremental_state)

        # prev_monotonic_step: bsz, num_heads
        bsz = bsz_num_heads // self.num_heads
        prev_monotonic_step = monotonic_cache.get(
            "step", p_choose.new_zeros([bsz, self.num_heads]).long()
        )
        bsz, num_heads = prev_monotonic_step.size()
        assert num_heads == self.num_heads
        assert bsz * num_heads == bsz_num_heads

        # p_choose: bsz, num_heads, src_len
        p_choose = p_choose.view(bsz, num_heads, src_len)

        if key_padding_mask is not None:
            src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
        else:
            src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len

        # src_lengths: bsz, num_heads
        src_lengths = src_lengths.expand_as(prev_monotonic_step)
        # new_monotonic_step: bsz, num_heads
        # NOTE(review): aliases prev_monotonic_step (no clone); the in-place
        # += below therefore also advances the cached tensor object.
        new_monotonic_step = prev_monotonic_step

        step_offset = 0
        if key_padding_mask is not None:
            if key_padding_mask[:, 0].any():
                # left_pad_source = True:
                step_offset = key_padding_mask.sum(dim=-1, keepdim=True)

        # With mass preservation the head parks on the final real token, so
        # the largest reachable step index is src_lengths - 1.
        max_steps = src_lengths - 1 if self.mass_preservation else src_lengths

        # finish_read: bsz, num_heads
        finish_read = new_monotonic_step.eq(max_steps)

        while finish_read.sum().item() < bsz * self.num_heads:
            # p_choose: bsz * self.num_heads, src_len
            # only choose the p at monotonic steps
            # p_choose_i: bsz , self.num_heads
            p_choose_i = (
                p_choose.gather(
                    2,
                    (step_offset + new_monotonic_step)
                    .unsqueeze(2)
                    .clamp(0, src_len - 1),
                )
            ).squeeze(2)

            # Deterministic hard decision: advance (read) while p < 0.5;
            # finished heads are frozen via masked_fill.
            action = (
                (p_choose_i < 0.5)
                .type_as(prev_monotonic_step)
                .masked_fill(finish_read, 0)
            )
            # 1 x bsz
            # sample actions on unfinished seq
            # 1 means stay, finish reading
            # 0 means leave, continue reading
            # dist = torch.distributions.bernoulli.Bernoulli(p_choose)
            # action = dist.sample().type_as(finish_read) * (1 - finish_read)

            new_monotonic_step += action

            finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
            # finish_read = (~ (finish_read.sum(dim=1, keepdim=True) < self.num_heads / 2)) | finish_read

        monotonic_cache["step"] = new_monotonic_step

        # Hard alignment: one-hot on the selected source step per head.
        # alpha: bsz * num_heads, 1, src_len
        # new_monotonic_step: bsz, num_heads
        alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
            1,
            (step_offset + new_monotonic_step)
            .view(bsz * self.num_heads, 1)
            .clamp(0, src_len - 1),
            1,
        )

        if not self.mass_preservation:
            # Heads that ran off the end attend to nothing.
            alpha = alpha.masked_fill(
                (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
            )

        alpha = alpha.unsqueeze(1)

        self._set_monotonic_buffer(incremental_state, monotonic_cache)

        return alpha

    def v_proj_output(self, value):
        # Value projection used for the final context; subclass hook.
        raise NotImplementedError

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        incremental_state=None,
        *args,
        **kwargs,
    ):
        """Compute monotonic attention.

        query: tgt_len, bsz, embed_dim; key/value: src_len, bsz, embed_dim.
        Returns (attn, extras) where attn is (tgt_len, bsz, embed_dim) and
        extras carries alpha/beta/p_choose shaped
        (bsz, num_heads, tgt_len, src_len).
        """

        tgt_len, bsz, embed_dim = query.size()
        src_len = value.size(0)

        # stepwise prob
        # p_choose: bsz * self.num_heads, tgt_len, src_len
        p_choose = self.p_choose(query, key, key_padding_mask)

        # expected alignment alpha
        # bsz * self.num_heads, tgt_len, src_len
        if incremental_state is not None:
            alpha = self.expected_alignment_infer(
                p_choose, key_padding_mask, incremental_state
            )
        else:
            alpha = self.expected_alignment_train(p_choose, key_padding_mask)

        # expected attention beta
        # bsz * self.num_heads, tgt_len, src_len
        beta = self.expected_attention(
            alpha, query, key, value, key_padding_mask, incremental_state
        )

        attn_weights = beta

        v_proj = self.v_proj_output(value)
        attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        attn = self.out_proj(attn)

        beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
        alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
        p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)

        return attn, {"alpha": alpha, "beta": beta, "p_choose": p_choose}

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        super().reorder_incremental_state(incremental_state, new_order)
        input_buffer = self._get_monotonic_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(0, new_order)
            self._set_monotonic_buffer(incremental_state, input_buffer)

    def _get_monotonic_buffer(self, incremental_state):
        # Returns the cached "monotonic" dict, or {} on first use.
        return (
            utils.get_incremental_state(
                self,
                incremental_state,
                "monotonic",
            )
            or {}
        )

    def _set_monotonic_buffer(self, incremental_state, buffer):
        # Stores the "monotonic" dict back into the incremental state.
        utils.set_incremental_state(
            self,
            incremental_state,
            "monotonic",
        buffer,
        )

    def get_pointer(self, incremental_state):
        # Same buffer as _get_monotonic_buffer; kept for the external
        # pointer-style API used by the decoder layers.
        return (
            utils.get_incremental_state(
                self,
                incremental_state,
                "monotonic",
            )
            or {}
        )

    def get_fastest_pointer(self, incremental_state):
        # Furthest read step across heads (max over dim 0 of "step").
        return self.get_pointer(incremental_state)["step"].max(0)[0]

    def set_pointer(self, incremental_state, p_choose):
        """Advance the cached read pointer by one wherever p_choose < 0.5."""
        curr_pointer = self.get_pointer(incremental_state)
        if len(curr_pointer) == 0:
            buffer = torch.zeros_like(p_choose)
        else:
            buffer = self.get_pointer(incremental_state)["step"]

        buffer += (p_choose < 0.5).type_as(buffer)

        utils.set_incremental_state(
            self,
            incremental_state,
            "monotonic",
            {"step": buffer},
        )
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
    """MMA-H: hard monotonic multihead attention.

    Uses the hard alignment alpha directly as the attention distribution
    (expected_attention is the identity on alpha).
    """

    def __init__(self, args):
        # Initialize the Module/projection machinery first, then the
        # monotonic bookkeeping (which assumes nn.Module is already set up).
        MultiheadAttention.__init__(
            self,
            embed_dim=args.decoder_embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        MonotonicAttention.__init__(self, args)

        # Named projection tables; subclasses add a "soft" entry.
        self.k_in_proj = {"monotonic": self.k_proj}
        self.q_in_proj = {"monotonic": self.q_proj}
        self.v_in_proj = {"output": self.v_proj}

    def input_projections(self, query, key, value, name):
        """
        Prepare inputs for multihead attention

        ============================================================
        Expected input size
        query: tgt_len, bsz, embed_dim
        key: src_len, bsz, embed_dim
        value: src_len, bsz, embed_dim
        name: monotonic or soft

        Any of query/key/value may be None; the corresponding output is None.
        Outputs are reshaped to (bsz * num_heads, len, head_dim).
        """

        if query is not None:
            bsz = query.size(1)
            q = self.q_in_proj[name](query)
            q *= self.scaling  # pre-scale queries (standard MHA trick)
            q = (
                q.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        else:
            q = None

        if key is not None:
            bsz = key.size(1)
            k = self.k_in_proj[name](key)
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        else:
            k = None

        if value is not None:
            bsz = value.size(1)
            v = self.v_in_proj[name](value)
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        else:
            v = None

        return q, k, v

    def p_choose(self, query, key, key_padding_mask=None):
        """
        Calculating step wise prob for reading and writing
        1 to read, 0 to write

        ============================================================
        Expected input size
        query: bsz, tgt_len, embed_dim
        key: bsz, src_len, embed_dim
        value: bsz, src_len, embed_dim
        key_padding_mask: bsz, src_len
        attn_mask: bsz, src_len
        query: bsz, tgt_len, embed_dim
        """

        # prepare inputs
        q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic")

        # attention energy
        attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)

        noise = 0

        if self.training:
            # add noise here to encourage discretness
            # NOTE(review): torch.normal's second argument is a standard
            # deviation, while --noise-var is documented as a variance;
            # equivalent only at the default 1.0 — confirm intent.
            noise = (
                torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
                .type_as(attn_energy)
                .to(attn_energy.device)
            )

        # sigmoid(-inf) == 0, so padded positions are never "read".
        p_choose = torch.sigmoid(attn_energy + noise)
        _, _, tgt_len, src_len = p_choose.size()

        # p_choose: bsz * self.num_heads, tgt_len, src_len
        return p_choose.view(-1, tgt_len, src_len)

    def expected_attention(self, alpha, *args):
        """
        For MMA-H, beta = alpha
        """
        return alpha

    def v_proj_output(self, value):
        # Project values with the "output" projection table entry.
        _, _, v_proj = self.input_projections(None, None, value, "output")
        return v_proj
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
@register_monotonic_attention("infinite_lookback")
class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHard):
    """MMA-IL: monotonic attention with infinite lookback.

    Adds a separate "soft" Q/K projection pair and computes a soft attention
    (beta) over all source positions up to the monotonic read head.
    """

    def __init__(self, args):
        super().__init__(args)
        self.init_soft_attention()

    def init_soft_attention(self):
        """Create and initialize the soft-attention Q/K projections."""
        self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
        self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.k_in_proj["soft"] = self.k_proj_soft
        self.q_in_proj["soft"] = self.q_proj_soft

        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(
                self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
            nn.init.xavier_uniform_(
                self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
        else:
            nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
            nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)

    def expected_attention(
        self, alpha, query, key, value, key_padding_mask, incremental_state
    ):
        """Soft attention beta given the expected alignment alpha (MILk)."""
        # monotonic attention, we will calculate milk here
        bsz_x_num_heads, tgt_len, src_len = alpha.size()
        bsz = int(bsz_x_num_heads / self.num_heads)

        q, k, _ = self.input_projections(query, key, None, "soft")
        soft_energy = self.attn_energy(q, k, key_padding_mask)

        assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len]

        soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            # Inference: plain softmax restricted to positions at or before
            # the current monotonic step (per head).
            monotonic_cache = self._get_monotonic_buffer(incremental_state)
            monotonic_step = monotonic_cache["step"] + 1
            step_offset = 0
            if key_padding_mask is not None:
                if key_padding_mask[:, 0].any():
                    # left_pad_source = True:
                    step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
            monotonic_step += step_offset
            mask = lengths_to_mask(
                monotonic_step.view(-1), soft_energy.size(2), 1
            ).unsqueeze(1)

            soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
            # Subtract the row max for numerical stability before exp.
            soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
            exp_soft_energy = torch.exp(soft_energy)
            exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
            beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2)

        else:
            # Training: expected soft attention under alpha, computed with a
            # cumsum over the reversed source axis (flip/cumsum/flip).
            # bsz * num_heads, tgt_len, src_len
            soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
            exp_soft_energy = torch.exp(soft_energy)
            exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2)

            if key_padding_mask is not None:
                if key_padding_mask.any():
                    # Replace padded cumsums with eps so the division below
                    # stays finite; alpha is 0 there anyway.
                    exp_soft_energy_cumsum = (
                        exp_soft_energy_cumsum.view(
                            -1, self.num_heads, tgt_len, src_len
                        )
                        .masked_fill(
                            key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
                        )
                        .view(-1, tgt_len, src_len)
                    )

            inner_items = alpha / exp_soft_energy_cumsum

            beta = exp_soft_energy * torch.cumsum(
                inner_items.flip(dims=[2]), dim=2
            ).flip(dims=[2])

            beta = self.dropout_module(beta)

        assert not torch.isnan(beta).any(), "NaN detected in beta."

        return beta
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
@register_monotonic_attention("waitk")
class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookback):
    """Wait-k attention: deterministic read/write schedule lagging the source
    by a fixed k tokens; reuses the monotonic projections for the soft path.
    """

    def __init__(self, args):
        super().__init__(args)
        # Share the monotonic projections for the soft attention (the extra
        # soft projections created by the parent are left unused).
        self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
        self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
        self.waitk_lagging = args.waitk_lagging
        assert (
            self.waitk_lagging > 0
        ), f"Lagging has to been larger than 0, get {self.waitk_lagging}."

    @staticmethod
    def add_args(parser):
        # Static method, so super() needs explicit (cls, cls) arguments.
        super(
            MonotonicMultiheadAttentionWaitk,
            MonotonicMultiheadAttentionWaitk,
        ).add_args(parser)

        parser.add_argument(
            "--waitk-lagging", type=int, required=True, help="Wait k lagging"
        )

    def p_choose(
        self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
    ):
        """
        query: bsz, tgt_len
        key: bsz, src_len
        key_padding_mask: bsz, src_len

        Returns a deterministic 0/1 schedule of shape
        (bsz * num_heads, tgt_len, src_len): target step i "writes" at source
        step i + k - 1 (tril+triu with the same diagonal keeps exactly that
        band).
        """
        src_len, bsz, _ = key.size()
        tgt_len, bsz, _ = query.size()
        p_choose = query.new_ones(bsz, tgt_len, src_len)
        p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1)
        p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1)

        if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
            # Left pad source
            # add -1 to the end
            p_choose = p_choose.masked_fill(
                key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
            )
            # Shift the schedule so it aligns with left-padded sources, using
            # -1 as a sentinel padding value.
            p_choose = convert_padding_direction(
                p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
            )
            p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)

            # remove -1
            p_choose[p_choose.eq(-1)] = 0

        # Extend to each head
        p_choose = (
            p_choose.contiguous()
            .unsqueeze(1)
            .expand(-1, self.num_heads, -1, -1)
            .contiguous()
            .view(-1, tgt_len, src_len)
        )

        return p_choose
|