sleepyhead111 commited on
Commit
74f3e76
·
verified ·
1 Parent(s): fdc723d

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. fairseq-0.10.2/.github/ISSUE_TEMPLATE.md +3 -0
  2. fairseq-0.10.2/.github/ISSUE_TEMPLATE/bug_report.md +43 -0
  3. fairseq-0.10.2/.github/ISSUE_TEMPLATE/documentation.md +15 -0
  4. fairseq-0.10.2/.github/ISSUE_TEMPLATE/feature_request.md +24 -0
  5. fairseq-0.10.2/.github/ISSUE_TEMPLATE/how-to-question.md +33 -0
  6. fairseq-0.10.2/config/config.yaml +7 -0
  7. fairseq-0.10.2/config/config_eval_lm.yaml +7 -0
  8. fairseq-0.10.2/config/criterion/adaptive_loss.yaml +3 -0
  9. fairseq-0.10.2/config/criterion/cross_entropy.yaml +3 -0
  10. fairseq-0.10.2/config/lr_scheduler/cosine.yaml +7 -0
  11. fairseq-0.10.2/config/lr_scheduler/inverse_sqrt.yaml +3 -0
  12. fairseq-0.10.2/config/model/transformer_lm.yaml +36 -0
  13. fairseq-0.10.2/config/model/transformer_lm_baevski_gbw.yaml +36 -0
  14. fairseq-0.10.2/config/model/transformer_lm_baevski_wiki103.yaml +36 -0
  15. fairseq-0.10.2/config/model/transformer_lm_big.yaml +36 -0
  16. fairseq-0.10.2/config/model/transformer_lm_gbw.yaml +36 -0
  17. fairseq-0.10.2/config/model/transformer_lm_gpt.yaml +36 -0
  18. fairseq-0.10.2/config/model/transformer_lm_gpt2_big.yaml +36 -0
  19. fairseq-0.10.2/config/model/transformer_lm_gpt2_medium.yaml +36 -0
  20. fairseq-0.10.2/config/model/transformer_lm_gpt2_small.yaml +36 -0
  21. fairseq-0.10.2/config/model/transformer_lm_wiki103.yaml +36 -0
  22. fairseq-0.10.2/config/optimizer/adam.yaml +5 -0
  23. fairseq-0.10.2/config/optimizer/nag.yaml +3 -0
  24. fairseq-0.10.2/config/params/eval_lm_params.yaml +105 -0
  25. fairseq-0.10.2/config/params/training_params.yaml +95 -0
  26. fairseq-0.10.2/config/task/language_modeling.yaml +10 -0
  27. fairseq-0.10.2/examples/noisychannel/README.md +72 -0
  28. fairseq-0.10.2/examples/noisychannel/__init__.py +6 -0
  29. fairseq-0.10.2/examples/noisychannel/rerank.py +422 -0
  30. fairseq-0.10.2/examples/noisychannel/rerank_generate.py +397 -0
  31. fairseq-0.10.2/examples/noisychannel/rerank_options.py +149 -0
  32. fairseq-0.10.2/examples/noisychannel/rerank_score_bw.py +143 -0
  33. fairseq-0.10.2/examples/noisychannel/rerank_tune.py +102 -0
  34. fairseq-0.10.2/examples/noisychannel/rerank_utils.py +850 -0
  35. fairseq-0.10.2/examples/paraphraser/README.md +46 -0
  36. fairseq-0.10.2/examples/paraphraser/paraphrase.py +85 -0
  37. fairseq-0.10.2/examples/simultaneous_translation/README.md +106 -0
  38. fairseq-0.10.2/examples/simultaneous_translation/__init__.py +6 -0
  39. fairseq-0.10.2/examples/simultaneous_translation/docs/baseline.md +178 -0
  40. fairseq-0.10.2/examples/simultaneous_translation/docs/evaluation.md +115 -0
  41. fairseq-0.10.2/examples/simultaneous_translation/eval/__init__.py +4 -0
  42. fairseq-0.10.2/examples/simultaneous_translation/eval/agents/word_splitter.py +91 -0
  43. fairseq-0.10.2/examples/simultaneous_translation/eval/client.py +100 -0
  44. fairseq-0.10.2/examples/simultaneous_translation/eval/eval_latency.py +78 -0
  45. fairseq-0.10.2/examples/simultaneous_translation/eval/evaluate.py +81 -0
  46. fairseq-0.10.2/examples/simultaneous_translation/eval/server.py +89 -0
  47. fairseq-0.10.2/examples/simultaneous_translation/models/__init__.py +15 -0
  48. fairseq-0.10.2/examples/simultaneous_translation/models/transformer_monotonic_attention.py +322 -0
  49. fairseq-0.10.2/examples/simultaneous_translation/modules/__init__.py +24 -0
  50. fairseq-0.10.2/examples/simultaneous_translation/modules/monotonic_multihead_attention.py +622 -0
fairseq-0.10.2/.github/ISSUE_TEMPLATE.md ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ ## 👉 [Please follow one of these issue templates](https://github.com/pytorch/fairseq/issues/new/choose) 👈
2
+
3
+ Note: to keep the backlog clean and actionable, issues may be immediately closed if they do not follow one of the above issue templates.
fairseq-0.10.2/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🐛 Bug Report
3
+ about: Submit a bug report to help us improve
4
+ labels: 'bug, needs triage'
5
+ ---
6
+
7
+ ## 🐛 Bug
8
+
9
+ <!-- A clear and concise description of what the bug is. -->
10
+
11
+ ### To Reproduce
12
+
13
+ Steps to reproduce the behavior (**always include the command you ran**):
14
+
15
+ 1. Run cmd '....'
16
+ 2. See error
17
+
18
+ <!-- If you have a code sample, error messages, stack traces, please provide it here as well -->
19
+
20
+
21
+ #### Code sample
22
+ <!-- Ideally attach a minimal code sample to reproduce the decried issue.
23
+ Minimal means having the shortest code but still preserving the bug. -->
24
+
25
+ ### Expected behavior
26
+
27
+ <!-- A clear and concise description of what you expected to happen. -->
28
+
29
+ ### Environment
30
+
31
+ - fairseq Version (e.g., 1.0 or master):
32
+ - PyTorch Version (e.g., 1.0)
33
+ - OS (e.g., Linux):
34
+ - How you installed fairseq (`pip`, source):
35
+ - Build command you used (if compiling from source):
36
+ - Python version:
37
+ - CUDA/cuDNN version:
38
+ - GPU models and configuration:
39
+ - Any other relevant information:
40
+
41
+ ### Additional context
42
+
43
+ <!-- Add any other context about the problem here. -->
fairseq-0.10.2/.github/ISSUE_TEMPLATE/documentation.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 📚 Documentation/Typos
3
+ about: Report an issue related to documentation or a typo
4
+ labels: 'documentation, needs triage'
5
+ ---
6
+
7
+ ## 📚 Documentation
8
+
9
+ For typos and doc fixes, please go ahead and:
10
+
11
+ 1. Create an issue.
12
+ 2. Fix the typo.
13
+ 3. Submit a PR.
14
+
15
+ Thanks!
fairseq-0.10.2/.github/ISSUE_TEMPLATE/feature_request.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: 🚀 Feature Request
3
+ about: Submit a proposal/request for a new feature
4
+ labels: 'enhancement, help wanted, needs triage'
5
+ ---
6
+
7
+ ## 🚀 Feature Request
8
+ <!-- A clear and concise description of the feature proposal -->
9
+
10
+ ### Motivation
11
+
12
+ <!-- Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too -->
13
+
14
+ ### Pitch
15
+
16
+ <!-- A clear and concise description of what you want to happen. -->
17
+
18
+ ### Alternatives
19
+
20
+ <!-- A clear and concise description of any alternative solutions or features you've considered, if any. -->
21
+
22
+ ### Additional context
23
+
24
+ <!-- Add any other context or screenshots about the feature request here. -->
fairseq-0.10.2/.github/ISSUE_TEMPLATE/how-to-question.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: ❓ Questions/Help
3
+ about: If you have questions, please first search existing issues and docs
4
+ labels: 'question, needs triage'
5
+ ---
6
+
7
+ ## ❓ Questions and Help
8
+
9
+ ### Before asking:
10
+ 1. search the issues.
11
+ 2. search the docs.
12
+
13
+ <!-- If you still can't find what you need: -->
14
+
15
+ #### What is your question?
16
+
17
+ #### Code
18
+
19
+ <!-- Please paste a code snippet if your question requires it! -->
20
+
21
+ #### What have you tried?
22
+
23
+ #### What's your environment?
24
+
25
+ - fairseq Version (e.g., 1.0 or master):
26
+ - PyTorch Version (e.g., 1.0)
27
+ - OS (e.g., Linux):
28
+ - How you installed fairseq (`pip`, source):
29
+ - Build command you used (if compiling from source):
30
+ - Python version:
31
+ - CUDA/cuDNN version:
32
+ - GPU models and configuration:
33
+ - Any other relevant information:
fairseq-0.10.2/config/config.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - params: training_params
3
+ - task: language_modeling
4
+ - model: transformer_lm
5
+ - criterion: cross_entropy
6
+ - optimizer: adam
7
+ - lr_scheduler: inverse_sqrt
fairseq-0.10.2/config/config_eval_lm.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - params: eval_lm_params
3
+ - task: language_modeling
4
+ - model: transformer_lm
5
+ - criterion: cross_entropy
6
+ - optimizer: adam
7
+ - lr_scheduler: inverse_sqrt
fairseq-0.10.2/config/criterion/adaptive_loss.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # @package _group_
2
+ sentence_avg: ${params.optimization.sentence_avg}
3
+ ddp_backend: ${params.distributed_training.ddp_backend}
fairseq-0.10.2/config/criterion/cross_entropy.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # @package _group_
2
+ sentence_avg: ${params.optimization.sentence_avg}
3
+ ddp_backend: ${params.distributed_training.ddp_backend}
fairseq-0.10.2/config/lr_scheduler/cosine.yaml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ warmup_updates: 0
3
+ warmup_init_lr: -1
4
+ max_lr: 1.0
5
+ t_mult: 1.0
6
+ lr_period_updates: -1
7
+ lr_shrink: 0.1
fairseq-0.10.2/config/lr_scheduler/inverse_sqrt.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # @package _group_
2
+ warmup_updates: 4000
3
+ warmup_init_lr: -1
fairseq-0.10.2/config/model/transformer_lm.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.0
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 512
8
+ decoder_output_dim: 512
9
+ decoder_input_dim: 512
10
+ decoder_ffn_embed_dim: 2048
11
+ decoder_layers: 6
12
+ decoder_attention_heads: 8
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_baevski_gbw.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 512
8
+ decoder_output_dim: 512
9
+ decoder_input_dim: 512
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 12
12
+ decoder_attention_heads: 16
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: true
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_baevski_wiki103.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.3
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.1
6
+ relu_dropout: 0.1
7
+ decoder_embed_dim: 1024
8
+ decoder_output_dim: 1024
9
+ decoder_input_dim: 1024
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 16
12
+ decoder_attention_heads: 8
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: true
15
+ adaptive_softmax_cutoff: "20000,60000"
16
+ adaptive_softmax_dropout: 0.2
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: true
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: "20000,60000"
27
+ tie_adaptive_weights: true
28
+ tie_adaptive_proj: true
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_big.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.0
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 1024
8
+ decoder_output_dim: 1024
9
+ decoder_input_dim: 1024
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 12
12
+ decoder_attention_heads: 16
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_gbw.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 512
8
+ decoder_output_dim: 512
9
+ decoder_input_dim: 512
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 12
12
+ decoder_attention_heads: 16
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: true
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_gpt.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "gelu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 768
8
+ decoder_output_dim: 768
9
+ decoder_input_dim: 768
10
+ decoder_ffn_embed_dim: 3072
11
+ decoder_layers: 12
12
+ decoder_attention_heads: 12
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_gpt2_big.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "gelu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 1600
8
+ decoder_output_dim: 1600
9
+ decoder_input_dim: 1600
10
+ decoder_ffn_embed_dim: 6400
11
+ decoder_layers: 48
12
+ decoder_attention_heads: 25
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_gpt2_medium.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "gelu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 1280
8
+ decoder_output_dim: 1280
9
+ decoder_input_dim: 1280
10
+ decoder_ffn_embed_dim: 5120
11
+ decoder_layers: 36
12
+ decoder_attention_heads: 20
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_gpt2_small.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "gelu"
3
+ dropout: 0.1
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.0
6
+ relu_dropout: 0.0
7
+ decoder_embed_dim: 1024
8
+ decoder_output_dim: 1024
9
+ decoder_input_dim: 1024
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 24
12
+ decoder_attention_heads: 16
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: false
15
+ adaptive_softmax_cutoff: null
16
+ adaptive_softmax_dropout: 0
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: false
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: null
27
+ tie_adaptive_weights: false
28
+ tie_adaptive_proj: false
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/model/transformer_lm_wiki103.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ activation_fn: "relu"
3
+ dropout: 0.3
4
+ attention_dropout: 0.1
5
+ activation_dropout: 0.1
6
+ relu_dropout: 0.1
7
+ decoder_embed_dim: 1024
8
+ decoder_output_dim: 1024
9
+ decoder_input_dim: 1024
10
+ decoder_ffn_embed_dim: 4096
11
+ decoder_layers: 16
12
+ decoder_attention_heads: 8
13
+ decoder_normalize_before: true
14
+ no_decoder_final_norm: true
15
+ adaptive_softmax_cutoff: "20000,60000"
16
+ adaptive_softmax_dropout: 0.2
17
+ adaptive_softmax_factor: 4
18
+ no_token_positional_embeddings: false
19
+ share_decoder_input_output_embed: false
20
+ character_embeddings: false
21
+ character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
22
+ character_embedding_dim: 4
23
+ char_embedder_highway_layers: 2
24
+ adaptive_input: true
25
+ adaptive_input_factor: 4
26
+ adaptive_input_cutoff: "20000,60000"
27
+ tie_adaptive_weights: true
28
+ tie_adaptive_proj: true
29
+ decoder_learned_pos: false
30
+ decoder_layerdrop: 0
31
+ decoder_layers_to_keep: null
32
+ layernorm_embedding: false
33
+ no_scale_embedding: false
34
+ quant_noise_pq: 0
35
+ quant_noise_pq_block_size: 8
36
+ quant_noise_scalar: 0
fairseq-0.10.2/config/optimizer/adam.yaml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # @package _group_
2
+ adam_betas: "(0.9, 0.999)"
3
+ adam_eps: 1.0e-8
4
+ weight_decay: 0
5
+ use_old_adam: false
fairseq-0.10.2/config/optimizer/nag.yaml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # @package _group_
2
+ momentum: 0.99
3
+ weight_decay: 0.0
fairseq-0.10.2/config/params/eval_lm_params.yaml ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ common:
3
+ no_progress_bar: false
4
+ log_interval: 100
5
+ log_format: null
6
+ tensorboard_logdir: null
7
+ seed: 1
8
+ cpu: false
9
+ fp16: false
10
+ memory_efficient_fp16: false
11
+ fp16_no_flatten_grads: false
12
+ fp16_init_scale: 128
13
+ fp16_scale_window: null
14
+ fp16_scale_tolerance: 0.0
15
+ min_loss_scale: 1.0e-4
16
+ threshold_loss_scale: null
17
+ user_dir: null
18
+ empty_cache_freq: 0
19
+ all_gather_list_size: 16384
20
+ model_parallel_size: 1
21
+ checkpoint_suffix: ""
22
+ quantization_config_path: null
23
+ distributed_training:
24
+ distributed_rank: 0
25
+ distributed_backend: "nccl"
26
+ distributed_init_method: null
27
+ distributed_port: -1
28
+ device_id: 0
29
+ local_rank: 0
30
+ distributed_no_spawn: false
31
+ ddp_backend: "c10d"
32
+ bucket_cap_mb: 25
33
+ fix_batches_to_gpus: false
34
+ find_unused_parameters: false
35
+ fast_stat_sync: false
36
+ broadcast_buffers: false
37
+ distributed_wrapper: "DDP"
38
+ slowmo_momentum: null
39
+ slowmo_algorithm: "LocalSGD"
40
+ localsgd_frequency: 3
41
+ dataset:
42
+ num_workers: 1
43
+ skip_invalid_size_inputs_valid_test: false
44
+ max_tokens: null
45
+ batch_size: ${params.dataset.batch_size}
46
+ required_batch_size_multiple: 8
47
+ dataset_impl: null
48
+ data_buffer_size: 10
49
+ train_subset: "train"
50
+ valid_subset: "valid"
51
+ validate_interval: 1
52
+ fixed_validation_seed: null
53
+ disable_validation: false
54
+ curriculum: 0
55
+ gen_subset: "test"
56
+ num_shards: 1
57
+ shard_id: 0
58
+ max_tokens_valid: ${params.dataset.max_tokens}
59
+ batch_size_valid: ${params.dataset.batch_size}
60
+ optimization:
61
+ max_epoch: 0
62
+ max_update: 0
63
+ clip_norm: 25.0
64
+ sentence_avg: false
65
+ update_freq: [1]
66
+ lr: [0.25]
67
+ min_lr: -1.0
68
+ use_bmuf: false
69
+ checkpoint:
70
+ save_dir: "checkpoints"
71
+ restore_file: "checkpoint_last.pt"
72
+ reset_dataloader: false
73
+ reset_lr_scheduler: false
74
+ reset_meters: false
75
+ reset_optimizer: false
76
+ optimizer_overrides: "{}"
77
+ save_interval: 1
78
+ save_interval_updates: 0
79
+ keep_interval_updates: -1
80
+ keep_last_epochs: -1
81
+ keep_best_checkpoints: -1
82
+ no_save: false
83
+ no_epoch_checkpoints: false
84
+ no_last_checkpoints: false
85
+ no_save_optimizer_state: false
86
+ best_checkpoint_metric: "loss"
87
+ maximize_best_checkpoint_metric: false
88
+ patience: -1
89
+ common_eval:
90
+ path: null
91
+ remove_bpe: null
92
+ quiet: false
93
+ model_overrides: "{}"
94
+ results_path: null
95
+ eval_lm:
96
+ output_word_probs: false
97
+ output_word_stats: false
98
+ context_window: 0
99
+ bmuf:
100
+ block_lr: 1
101
+ block_momentum: 0.875
102
+ global_sync_iter: 50
103
+ warmup_iterations: 500
104
+ use_nbm: false
105
+ average_sync: false
fairseq-0.10.2/config/params/training_params.yaml ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ common:
3
+ no_progress_bar: false
4
+ log_interval: 100
5
+ log_format: null
6
+ tensorboard_logdir: null
7
+ seed: 1
8
+ cpu: false
9
+ fp16: false
10
+ memory_efficient_fp16: false
11
+ fp16_no_flatten_grads: false
12
+ fp16_init_scale: 128
13
+ fp16_scale_window: null
14
+ fp16_scale_tolerance: 0.0
15
+ min_loss_scale: 1.0e-4
16
+ threshold_loss_scale: null
17
+ user_dir: null
18
+ empty_cache_freq: 0
19
+ all_gather_list_size: 16384
20
+ model_parallel_size: 1
21
+ checkpoint_suffix: ""
22
+ quantization_config_path: null
23
+ distributed_training:
24
+ distributed_rank: 0
25
+ distributed_backend: "nccl"
26
+ distributed_init_method: null
27
+ distributed_port: -1
28
+ device_id: 0
29
+ local_rank: 0
30
+ distributed_no_spawn: false
31
+ ddp_backend: "c10d"
32
+ bucket_cap_mb: 25
33
+ fix_batches_to_gpus: false
34
+ find_unused_parameters: false
35
+ fast_stat_sync: false
36
+ broadcast_buffers: false
37
+ distributed_wrapper: "DDP"
38
+ slowmo_momentum: null
39
+ slowmo_algorithm: "LocalSGD"
40
+ localsgd_frequency: 3
41
+ dataset:
42
+ num_workers: 1
43
+ skip_invalid_size_inputs_valid_test: false
44
+ max_tokens: null
45
+ batch_size: ${params.dataset.batch_size}
46
+ required_batch_size_multiple: 8
47
+ dataset_impl: null
48
+ data_buffer_size: 10
49
+ train_subset: "train"
50
+ valid_subset: "valid"
51
+ validate_interval: 1
52
+ fixed_validation_seed: null
53
+ disable_validation: false
54
+ curriculum: 0
55
+ gen_subset: "test"
56
+ num_shards: 1
57
+ shard_id: 0
58
+ max_tokens_valid: ${params.dataset.max_tokens}
59
+ batch_size_valid: ${params.dataset.batch_size}
60
+ optimization:
61
+ max_epoch: 0
62
+ max_update: 0
63
+ clip_norm: 25.0
64
+ sentence_avg: false
65
+ update_freq: [1]
66
+ lr: [0.25]
67
+ min_lr: -1.0
68
+ use_bmuf: false
69
+ checkpoint:
70
+ save_dir: "checkpoints"
71
+ restore_file: "checkpoint_last.pt"
72
+ reset_dataloader: false
73
+ reset_lr_scheduler: false
74
+ reset_meters: false
75
+ reset_optimizer: false
76
+ optimizer_overrides: "{}"
77
+ save_interval: 1
78
+ save_interval_updates: 0
79
+ keep_interval_updates: -1
80
+ keep_last_epochs: -1
81
+ keep_best_checkpoints: -1
82
+ no_save: false
83
+ no_epoch_checkpoints: false
84
+ no_last_checkpoints: false
85
+ no_save_optimizer_state: false
86
+ best_checkpoint_metric: "loss"
87
+ maximize_best_checkpoint_metric: false
88
+ patience: -1
89
+ bmuf:
90
+ block_lr: 1
91
+ block_momentum: 0.875
92
+ global_sync_iter: 50
93
+ warmup_iterations: 500
94
+ use_nbm: false
95
+ average_sync: false
fairseq-0.10.2/config/task/language_modeling.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package _group_
2
+ data: ???
3
+ sample_break_mode: "none"
4
+ tokens_per_sample: 1024
5
+ output_dictionary_size: -1
6
+ self_target: false
7
+ future_target: false
8
+ past_target: false
9
+ add_bos_token: false
10
+ max_target_positions: null
fairseq-0.10.2/examples/noisychannel/README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Simple and Effective Noisy Channel Modeling for Neural Machine Translation (Yee et al., 2019)
2
+ This page contains pointers to pre-trained models as well as instructions on how to run the reranking scripts.
3
+
4
+ ## Citation:
5
+ ```bibtex
6
+ @inproceedings{yee2019simple,
7
+ title = {Simple and Effective Noisy Channel Modeling for Neural Machine Translation},
8
+ author = {Kyra Yee and Yann Dauphin and Michael Auli},
9
+ booktitle = {Conference on Empirical Methods in Natural Language Processing},
10
+ year = {2019},
11
+ }
12
+ ```
13
+
14
+ ## Pre-trained Models:
15
+
16
+ Model | Description | Download
17
+ ---|---|---
18
+ `transformer.noisychannel.de-en` | De->En Forward Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2)
19
+ `transformer.noisychannel.en-de` | En->De Channel Model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2)
20
+ `transformer_lm.noisychannel.en` | En Language model | [download (.tar.gz)](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2)
21
+
22
+ Test Data: [newstest_wmt17](https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2)
23
+
24
+ ## Example usage
25
+
26
+ ```
27
+ mkdir rerank_example
28
+ curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/forward_de2en.tar.bz2 | tar xvjf - -C rerank_example
29
+ curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/backward_en2de.tar.bz2 | tar xvjf - -C rerank_example
30
+ curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/reranking_en_lm.tar.bz2 | tar xvjf - -C rerank_example
31
+ curl https://dl.fbaipublicfiles.com/fairseq/models/noisychannel/wmt17test.tar.bz2 | tar xvjf - -C rerank_example
32
+
33
+ beam=50
34
+ num_trials=1000
35
+ fw_name=fw_model_ex
36
+ bw_name=bw_model_ex
37
+ lm_name=lm_ex
38
+ data_dir=rerank_example/hyphen-splitting-mixed-case-wmt17test-wmt14bpe
39
+ data_dir_name=wmt17
40
+ lm=rerank_example/lm/checkpoint_best.pt
41
+ lm_bpe_code=rerank_example/lm/bpe32k.code
42
+ lm_dict=rerank_example/lm/dict.txt
43
+ batch_size=32
44
+ bw=rerank_example/backward_en2de.pt
45
+ fw=rerank_example/forward_de2en.pt
46
+
47
+ # reranking with P(T|S) P(S|T) and P(T)
48
+ python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight1 weight3 \
49
+ --lower-bound 0 0 0 --upper-bound 3 3 3 --data-dir-name $data_dir_name \
50
+ --num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
51
+ -n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw \
52
+ --backwards1 --weight2 1 \
53
+ -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
54
+ --model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
55
+
56
+ # reranking with P(T|S) and P(T)
57
+ python examples/noisychannel/rerank_tune.py $data_dir --tune-param lenpen weight3 \
58
+ --lower-bound 0 0 --upper-bound 3 3 --data-dir-name $data_dir_name \
59
+ --num-trials $num_trials --source-lang de --target-lang en --gen-model $fw \
60
+ -n $beam --batch-size $batch_size --score-model1 $fw \
61
+ -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
62
+ --model1-name $fw_name --gen-model-name $fw_name
63
+
64
+ # to run with a preconfigured set of hyperparameters for the lenpen and model weights, using rerank.py instead.
65
+ python examples/noisychannel/rerank.py $data_dir \
66
+ --lenpen 0.269 --weight1 1 --weight2 0.929 --weight3 0.831 \
67
+ --data-dir-name $data_dir_name --source-lang de --target-lang en --gen-model $fw \
68
+ -n $beam --batch-size $batch_size --score-model2 $fw --score-model1 $bw --backwards1 \
69
+ -lm $lm --lm-dict $lm_dict --lm-name en_newscrawl --lm-bpe-code $lm_bpe_code \
70
+ --model2-name $fw_name --model1-name $bw_name --gen-model-name $fw_name
71
+ ```
72
+
fairseq-0.10.2/examples/noisychannel/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .rerank_options import * # noqa
fairseq-0.10.2/examples/noisychannel/rerank.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ from multiprocessing import Pool
8
+
9
+ import numpy as np
10
+ from fairseq import options
11
+ from fairseq.data import dictionary
12
+ from fairseq.scoring import bleu
13
+
14
+ from . import (
15
+ rerank_generate,
16
+ rerank_options,
17
+ rerank_score_bw,
18
+ rerank_score_lm,
19
+ rerank_utils,
20
+ )
21
+
22
+
23
+ def score_target_hypo(
24
+ args, a, b, c, lenpen, target_outfile, hypo_outfile, write_hypos, normalize
25
+ ):
26
+
27
+ print("lenpen", lenpen, "weight1", a, "weight2", b, "weight3", c)
28
+ gen_output_lst, bitext1_lst, bitext2_lst, lm_res_lst = load_score_files(args)
29
+ dict = dictionary.Dictionary()
30
+ scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
31
+
32
+ ordered_hypos = {}
33
+ ordered_targets = {}
34
+
35
+ for shard_id in range(len(bitext1_lst)):
36
+ bitext1 = bitext1_lst[shard_id]
37
+ bitext2 = bitext2_lst[shard_id]
38
+ gen_output = gen_output_lst[shard_id]
39
+ lm_res = lm_res_lst[shard_id]
40
+
41
+ total = len(bitext1.rescore_source.keys())
42
+ source_lst = []
43
+ hypo_lst = []
44
+ score_lst = []
45
+ reference_lst = []
46
+ j = 1
47
+ best_score = -math.inf
48
+
49
+ for i in range(total):
50
+ # length is measured in terms of words, not bpe tokens, since models may not share the same bpe
51
+ target_len = len(bitext1.rescore_hypo[i].split())
52
+
53
+ if lm_res is not None:
54
+ lm_score = lm_res.score[i]
55
+ else:
56
+ lm_score = 0
57
+
58
+ if bitext2 is not None:
59
+ bitext2_score = bitext2.rescore_score[i]
60
+ bitext2_backwards = bitext2.backwards
61
+ else:
62
+ bitext2_score = None
63
+ bitext2_backwards = None
64
+
65
+ score = rerank_utils.get_score(
66
+ a,
67
+ b,
68
+ c,
69
+ target_len,
70
+ bitext1.rescore_score[i],
71
+ bitext2_score,
72
+ lm_score=lm_score,
73
+ lenpen=lenpen,
74
+ src_len=bitext1.source_lengths[i],
75
+ tgt_len=bitext1.target_lengths[i],
76
+ bitext1_backwards=bitext1.backwards,
77
+ bitext2_backwards=bitext2_backwards,
78
+ normalize=normalize,
79
+ )
80
+
81
+ if score > best_score:
82
+ best_score = score
83
+ best_hypo = bitext1.rescore_hypo[i]
84
+
85
+ if j == gen_output.num_hypos[i] or j == args.num_rescore:
86
+ j = 1
87
+ hypo_lst.append(best_hypo)
88
+ score_lst.append(best_score)
89
+ source_lst.append(bitext1.rescore_source[i])
90
+ reference_lst.append(bitext1.rescore_target[i])
91
+
92
+ best_score = -math.inf
93
+ best_hypo = ""
94
+ else:
95
+ j += 1
96
+
97
+ gen_keys = list(sorted(gen_output.no_bpe_target.keys()))
98
+
99
+ for key in range(len(gen_keys)):
100
+ if args.prefix_len is None:
101
+ assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
102
+ "pred and rescore hypo mismatch: i: "
103
+ + str(key)
104
+ + ", "
105
+ + str(hypo_lst[key])
106
+ + str(gen_keys[key])
107
+ + str(gen_output.no_bpe_hypo[key])
108
+ )
109
+ sys_tok = dict.encode_line(hypo_lst[key])
110
+ ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
111
+ scorer.add(ref_tok, sys_tok)
112
+
113
+ else:
114
+ full_hypo = rerank_utils.get_full_from_prefix(
115
+ hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
116
+ )
117
+ sys_tok = dict.encode_line(full_hypo)
118
+ ref_tok = dict.encode_line(gen_output.no_bpe_target[gen_keys[key]])
119
+ scorer.add(ref_tok, sys_tok)
120
+
121
+ # if only one set of hyperparameters is provided, write the predictions to a file
122
+ if write_hypos:
123
+ # recover the original ids from n-best list generation
124
+ for key in range(len(gen_output.no_bpe_target)):
125
+ if args.prefix_len is None:
126
+ assert hypo_lst[key] in gen_output.no_bpe_hypo[gen_keys[key]], (
127
+ "pred and rescore hypo mismatch:"
128
+ + "i:"
129
+ + str(key)
130
+ + str(hypo_lst[key])
131
+ + str(gen_output.no_bpe_hypo[key])
132
+ )
133
+ ordered_hypos[gen_keys[key]] = hypo_lst[key]
134
+ ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
135
+ gen_keys[key]
136
+ ]
137
+
138
+ else:
139
+ full_hypo = rerank_utils.get_full_from_prefix(
140
+ hypo_lst[key], gen_output.no_bpe_hypo[gen_keys[key]]
141
+ )
142
+ ordered_hypos[gen_keys[key]] = full_hypo
143
+ ordered_targets[gen_keys[key]] = gen_output.no_bpe_target[
144
+ gen_keys[key]
145
+ ]
146
+
147
+ # write the hypos in the original order from nbest list generation
148
+ if args.num_shards == (len(bitext1_lst)):
149
+ with open(target_outfile, "w") as t:
150
+ with open(hypo_outfile, "w") as h:
151
+ for key in range(len(ordered_hypos)):
152
+ t.write(ordered_targets[key])
153
+ h.write(ordered_hypos[key])
154
+
155
+ res = scorer.result_string(4)
156
+ if write_hypos:
157
+ print(res)
158
+ score = rerank_utils.parse_bleu_scoring(res)
159
+ return score
160
+
161
+
162
+ def match_target_hypo(args, target_outfile, hypo_outfile):
163
+ """combine scores from the LM and bitext models, and write the top scoring hypothesis to a file"""
164
+ if len(args.weight1) == 1:
165
+ res = score_target_hypo(
166
+ args,
167
+ args.weight1[0],
168
+ args.weight2[0],
169
+ args.weight3[0],
170
+ args.lenpen[0],
171
+ target_outfile,
172
+ hypo_outfile,
173
+ True,
174
+ args.normalize,
175
+ )
176
+ rerank_scores = [res]
177
+ else:
178
+ print("launching pool")
179
+ with Pool(32) as p:
180
+ rerank_scores = p.starmap(
181
+ score_target_hypo,
182
+ [
183
+ (
184
+ args,
185
+ args.weight1[i],
186
+ args.weight2[i],
187
+ args.weight3[i],
188
+ args.lenpen[i],
189
+ target_outfile,
190
+ hypo_outfile,
191
+ False,
192
+ args.normalize,
193
+ )
194
+ for i in range(len(args.weight1))
195
+ ],
196
+ )
197
+
198
+ if len(rerank_scores) > 1:
199
+ best_index = np.argmax(rerank_scores)
200
+ best_score = rerank_scores[best_index]
201
+ print("best score", best_score)
202
+ print("best lenpen", args.lenpen[best_index])
203
+ print("best weight1", args.weight1[best_index])
204
+ print("best weight2", args.weight2[best_index])
205
+ print("best weight3", args.weight3[best_index])
206
+ return (
207
+ args.lenpen[best_index],
208
+ args.weight1[best_index],
209
+ args.weight2[best_index],
210
+ args.weight3[best_index],
211
+ best_score,
212
+ )
213
+
214
+ else:
215
+ return (
216
+ args.lenpen[0],
217
+ args.weight1[0],
218
+ args.weight2[0],
219
+ args.weight3[0],
220
+ rerank_scores[0],
221
+ )
222
+
223
+
224
+ def load_score_files(args):
225
+ if args.all_shards:
226
+ shard_ids = list(range(args.num_shards))
227
+ else:
228
+ shard_ids = [args.shard_id]
229
+
230
+ gen_output_lst = []
231
+ bitext1_lst = []
232
+ bitext2_lst = []
233
+ lm_res1_lst = []
234
+
235
+ for shard_id in shard_ids:
236
+ using_nbest = args.nbest_list is not None
237
+ (
238
+ pre_gen,
239
+ left_to_right_preprocessed_dir,
240
+ right_to_left_preprocessed_dir,
241
+ backwards_preprocessed_dir,
242
+ lm_preprocessed_dir,
243
+ ) = rerank_utils.get_directories(
244
+ args.data_dir_name,
245
+ args.num_rescore,
246
+ args.gen_subset,
247
+ args.gen_model_name,
248
+ shard_id,
249
+ args.num_shards,
250
+ args.sampling,
251
+ args.prefix_len,
252
+ args.target_prefix_frac,
253
+ args.source_prefix_frac,
254
+ )
255
+
256
+ rerank1_is_gen = (
257
+ args.gen_model == args.score_model1 and args.source_prefix_frac is None
258
+ )
259
+ rerank2_is_gen = (
260
+ args.gen_model == args.score_model2 and args.source_prefix_frac is None
261
+ )
262
+
263
+ score1_file = rerank_utils.rescore_file_name(
264
+ pre_gen,
265
+ args.prefix_len,
266
+ args.model1_name,
267
+ target_prefix_frac=args.target_prefix_frac,
268
+ source_prefix_frac=args.source_prefix_frac,
269
+ backwards=args.backwards1,
270
+ )
271
+ if args.score_model2 is not None:
272
+ score2_file = rerank_utils.rescore_file_name(
273
+ pre_gen,
274
+ args.prefix_len,
275
+ args.model2_name,
276
+ target_prefix_frac=args.target_prefix_frac,
277
+ source_prefix_frac=args.source_prefix_frac,
278
+ backwards=args.backwards2,
279
+ )
280
+ if args.language_model is not None:
281
+ lm_score_file = rerank_utils.rescore_file_name(
282
+ pre_gen, args.prefix_len, args.lm_name, lm_file=True
283
+ )
284
+
285
+ # get gen output
286
+ predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
287
+ if using_nbest:
288
+ print("Using predefined n-best list from interactive.py")
289
+ predictions_bpe_file = args.nbest_list
290
+ gen_output = rerank_utils.BitextOutputFromGen(
291
+ predictions_bpe_file,
292
+ bpe_symbol=args.remove_bpe,
293
+ nbest=using_nbest,
294
+ prefix_len=args.prefix_len,
295
+ target_prefix_frac=args.target_prefix_frac,
296
+ )
297
+
298
+ if rerank1_is_gen:
299
+ bitext1 = gen_output
300
+ else:
301
+ bitext1 = rerank_utils.BitextOutput(
302
+ score1_file,
303
+ args.backwards1,
304
+ args.right_to_left1,
305
+ args.remove_bpe,
306
+ args.prefix_len,
307
+ args.target_prefix_frac,
308
+ args.source_prefix_frac,
309
+ )
310
+
311
+ if args.score_model2 is not None or args.nbest_list is not None:
312
+ if rerank2_is_gen:
313
+ bitext2 = gen_output
314
+ else:
315
+ bitext2 = rerank_utils.BitextOutput(
316
+ score2_file,
317
+ args.backwards2,
318
+ args.right_to_left2,
319
+ args.remove_bpe,
320
+ args.prefix_len,
321
+ args.target_prefix_frac,
322
+ args.source_prefix_frac,
323
+ )
324
+
325
+ assert (
326
+ bitext2.source_lengths == bitext1.source_lengths
327
+ ), "source lengths for rescoring models do not match"
328
+ assert (
329
+ bitext2.target_lengths == bitext1.target_lengths
330
+ ), "target lengths for rescoring models do not match"
331
+ else:
332
+ if args.diff_bpe:
333
+ assert args.score_model2 is None
334
+ bitext2 = gen_output
335
+ else:
336
+ bitext2 = None
337
+
338
+ if args.language_model is not None:
339
+ lm_res1 = rerank_utils.LMOutput(
340
+ lm_score_file,
341
+ args.lm_dict,
342
+ args.prefix_len,
343
+ args.remove_bpe,
344
+ args.target_prefix_frac,
345
+ )
346
+ else:
347
+ lm_res1 = None
348
+
349
+ gen_output_lst.append(gen_output)
350
+ bitext1_lst.append(bitext1)
351
+ bitext2_lst.append(bitext2)
352
+ lm_res1_lst.append(lm_res1)
353
+ return gen_output_lst, bitext1_lst, bitext2_lst, lm_res1_lst
354
+
355
+
356
+ def rerank(args):
357
+ if type(args.lenpen) is not list:
358
+ args.lenpen = [args.lenpen]
359
+ if type(args.weight1) is not list:
360
+ args.weight1 = [args.weight1]
361
+ if type(args.weight2) is not list:
362
+ args.weight2 = [args.weight2]
363
+ if type(args.weight3) is not list:
364
+ args.weight3 = [args.weight3]
365
+ if args.all_shards:
366
+ shard_ids = list(range(args.num_shards))
367
+ else:
368
+ shard_ids = [args.shard_id]
369
+
370
+ for shard_id in shard_ids:
371
+ (
372
+ pre_gen,
373
+ left_to_right_preprocessed_dir,
374
+ right_to_left_preprocessed_dir,
375
+ backwards_preprocessed_dir,
376
+ lm_preprocessed_dir,
377
+ ) = rerank_utils.get_directories(
378
+ args.data_dir_name,
379
+ args.num_rescore,
380
+ args.gen_subset,
381
+ args.gen_model_name,
382
+ shard_id,
383
+ args.num_shards,
384
+ args.sampling,
385
+ args.prefix_len,
386
+ args.target_prefix_frac,
387
+ args.source_prefix_frac,
388
+ )
389
+ rerank_generate.gen_and_reprocess_nbest(args)
390
+ rerank_score_bw.score_bw(args)
391
+ rerank_score_lm.score_lm(args)
392
+
393
+ if args.write_hypos is None:
394
+ write_targets = pre_gen + "/matched_targets"
395
+ write_hypos = pre_gen + "/matched_hypos"
396
+ else:
397
+ write_targets = args.write_hypos + "_targets" + args.gen_subset
398
+ write_hypos = args.write_hypos + "_hypos" + args.gen_subset
399
+
400
+ if args.all_shards:
401
+ write_targets += "_all_shards"
402
+ write_hypos += "_all_shards"
403
+
404
+ (
405
+ best_lenpen,
406
+ best_weight1,
407
+ best_weight2,
408
+ best_weight3,
409
+ best_score,
410
+ ) = match_target_hypo(args, write_targets, write_hypos)
411
+
412
+ return best_lenpen, best_weight1, best_weight2, best_weight3, best_score
413
+
414
+
415
+ def cli_main():
416
+ parser = rerank_options.get_reranking_parser()
417
+ args = options.parse_args_and_arch(parser)
418
+ rerank(args)
419
+
420
+
421
+ if __name__ == "__main__":
422
+ cli_main()
fairseq-0.10.2/examples/noisychannel/rerank_generate.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+ # Copyright (c) Facebook, Inc. and its affiliates.
3
+ #
4
+ # This source code is licensed under the MIT license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Generate n-best translations using a trained model.
9
+ """
10
+
11
+ import os
12
+ import subprocess
13
+ from contextlib import redirect_stdout
14
+
15
+ from fairseq import options
16
+ from fairseq_cli import generate, preprocess
17
+
18
+ from . import rerank_options, rerank_utils
19
+
20
+
21
+ def gen_and_reprocess_nbest(args):
22
+ if args.score_dict_dir is None:
23
+ args.score_dict_dir = args.data
24
+ if args.prefix_len is not None:
25
+ assert (
26
+ args.right_to_left1 is False
27
+ ), "prefix length not compatible with right to left models"
28
+ assert (
29
+ args.right_to_left2 is False
30
+ ), "prefix length not compatible with right to left models"
31
+
32
+ if args.nbest_list is not None:
33
+ assert args.score_model2 is None
34
+
35
+ if args.backwards1:
36
+ scorer1_src = args.target_lang
37
+ scorer1_tgt = args.source_lang
38
+ else:
39
+ scorer1_src = args.source_lang
40
+ scorer1_tgt = args.target_lang
41
+
42
+ store_data = (
43
+ os.path.join(os.path.dirname(__file__)) + "/rerank_data/" + args.data_dir_name
44
+ )
45
+ if not os.path.exists(store_data):
46
+ os.makedirs(store_data)
47
+
48
+ (
49
+ pre_gen,
50
+ left_to_right_preprocessed_dir,
51
+ right_to_left_preprocessed_dir,
52
+ backwards_preprocessed_dir,
53
+ lm_preprocessed_dir,
54
+ ) = rerank_utils.get_directories(
55
+ args.data_dir_name,
56
+ args.num_rescore,
57
+ args.gen_subset,
58
+ args.gen_model_name,
59
+ args.shard_id,
60
+ args.num_shards,
61
+ args.sampling,
62
+ args.prefix_len,
63
+ args.target_prefix_frac,
64
+ args.source_prefix_frac,
65
+ )
66
+ assert not (
67
+ args.right_to_left1 and args.backwards1
68
+ ), "backwards right to left not supported"
69
+ assert not (
70
+ args.right_to_left2 and args.backwards2
71
+ ), "backwards right to left not supported"
72
+ assert not (
73
+ args.prefix_len is not None and args.target_prefix_frac is not None
74
+ ), "target prefix frac and target prefix len incompatible"
75
+
76
+ # make directory to store generation results
77
+ if not os.path.exists(pre_gen):
78
+ os.makedirs(pre_gen)
79
+
80
+ rerank1_is_gen = (
81
+ args.gen_model == args.score_model1 and args.source_prefix_frac is None
82
+ )
83
+ rerank2_is_gen = (
84
+ args.gen_model == args.score_model2 and args.source_prefix_frac is None
85
+ )
86
+
87
+ if args.nbest_list is not None:
88
+ rerank2_is_gen = True
89
+
90
+ # make directories to store the preprocessed n-best list for reranking
91
+ if not os.path.exists(left_to_right_preprocessed_dir):
92
+ os.makedirs(left_to_right_preprocessed_dir)
93
+ if not os.path.exists(right_to_left_preprocessed_dir):
94
+ os.makedirs(right_to_left_preprocessed_dir)
95
+ if not os.path.exists(lm_preprocessed_dir):
96
+ os.makedirs(lm_preprocessed_dir)
97
+ if not os.path.exists(backwards_preprocessed_dir):
98
+ os.makedirs(backwards_preprocessed_dir)
99
+
100
+ score1_file = rerank_utils.rescore_file_name(
101
+ pre_gen,
102
+ args.prefix_len,
103
+ args.model1_name,
104
+ target_prefix_frac=args.target_prefix_frac,
105
+ source_prefix_frac=args.source_prefix_frac,
106
+ backwards=args.backwards1,
107
+ )
108
+ if args.score_model2 is not None:
109
+ score2_file = rerank_utils.rescore_file_name(
110
+ pre_gen,
111
+ args.prefix_len,
112
+ args.model2_name,
113
+ target_prefix_frac=args.target_prefix_frac,
114
+ source_prefix_frac=args.source_prefix_frac,
115
+ backwards=args.backwards2,
116
+ )
117
+
118
+ predictions_bpe_file = pre_gen + "/generate_output_bpe.txt"
119
+
120
+ using_nbest = args.nbest_list is not None
121
+
122
+ if using_nbest:
123
+ print("Using predefined n-best list from interactive.py")
124
+ predictions_bpe_file = args.nbest_list
125
+
126
+ else:
127
+ if not os.path.isfile(predictions_bpe_file):
128
+ print("STEP 1: generate predictions using the p(T|S) model with bpe")
129
+ print(args.data)
130
+ param1 = [
131
+ args.data,
132
+ "--path",
133
+ args.gen_model,
134
+ "--shard-id",
135
+ str(args.shard_id),
136
+ "--num-shards",
137
+ str(args.num_shards),
138
+ "--nbest",
139
+ str(args.num_rescore),
140
+ "--batch-size",
141
+ str(args.batch_size),
142
+ "--beam",
143
+ str(args.num_rescore),
144
+ "--batch-size",
145
+ str(args.num_rescore),
146
+ "--gen-subset",
147
+ args.gen_subset,
148
+ "--source-lang",
149
+ args.source_lang,
150
+ "--target-lang",
151
+ args.target_lang,
152
+ ]
153
+ if args.sampling:
154
+ param1 += ["--sampling"]
155
+
156
+ gen_parser = options.get_generation_parser()
157
+ input_args = options.parse_args_and_arch(gen_parser, param1)
158
+
159
+ print(input_args)
160
+ with open(predictions_bpe_file, "w") as f:
161
+ with redirect_stdout(f):
162
+ generate.main(input_args)
163
+
164
+ gen_output = rerank_utils.BitextOutputFromGen(
165
+ predictions_bpe_file,
166
+ bpe_symbol=args.remove_bpe,
167
+ nbest=using_nbest,
168
+ prefix_len=args.prefix_len,
169
+ target_prefix_frac=args.target_prefix_frac,
170
+ )
171
+
172
+ if args.diff_bpe:
173
+ rerank_utils.write_reprocessed(
174
+ gen_output.no_bpe_source,
175
+ gen_output.no_bpe_hypo,
176
+ gen_output.no_bpe_target,
177
+ pre_gen + "/source_gen_bpe." + args.source_lang,
178
+ pre_gen + "/target_gen_bpe." + args.target_lang,
179
+ pre_gen + "/reference_gen_bpe." + args.target_lang,
180
+ )
181
+ bitext_bpe = args.rescore_bpe_code
182
+ bpe_src_param = [
183
+ "-c",
184
+ bitext_bpe,
185
+ "--input",
186
+ pre_gen + "/source_gen_bpe." + args.source_lang,
187
+ "--output",
188
+ pre_gen + "/rescore_data." + args.source_lang,
189
+ ]
190
+ bpe_tgt_param = [
191
+ "-c",
192
+ bitext_bpe,
193
+ "--input",
194
+ pre_gen + "/target_gen_bpe." + args.target_lang,
195
+ "--output",
196
+ pre_gen + "/rescore_data." + args.target_lang,
197
+ ]
198
+
199
+ subprocess.call(
200
+ [
201
+ "python",
202
+ os.path.join(
203
+ os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
204
+ ),
205
+ ]
206
+ + bpe_src_param,
207
+ shell=False,
208
+ )
209
+
210
+ subprocess.call(
211
+ [
212
+ "python",
213
+ os.path.join(
214
+ os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
215
+ ),
216
+ ]
217
+ + bpe_tgt_param,
218
+ shell=False,
219
+ )
220
+
221
+ if (not os.path.isfile(score1_file) and not rerank1_is_gen) or (
222
+ args.score_model2 is not None
223
+ and not os.path.isfile(score2_file)
224
+ and not rerank2_is_gen
225
+ ):
226
+ print(
227
+ "STEP 2: process the output of generate.py so we have clean text files with the translations"
228
+ )
229
+
230
+ rescore_file = "/rescore_data"
231
+ if args.prefix_len is not None:
232
+ prefix_len_rescore_file = rescore_file + "prefix" + str(args.prefix_len)
233
+ if args.target_prefix_frac is not None:
234
+ target_prefix_frac_rescore_file = (
235
+ rescore_file + "target_prefix_frac" + str(args.target_prefix_frac)
236
+ )
237
+ if args.source_prefix_frac is not None:
238
+ source_prefix_frac_rescore_file = (
239
+ rescore_file + "source_prefix_frac" + str(args.source_prefix_frac)
240
+ )
241
+
242
+ if not args.right_to_left1 or not args.right_to_left2:
243
+ if not args.diff_bpe:
244
+ rerank_utils.write_reprocessed(
245
+ gen_output.source,
246
+ gen_output.hypo,
247
+ gen_output.target,
248
+ pre_gen + rescore_file + "." + args.source_lang,
249
+ pre_gen + rescore_file + "." + args.target_lang,
250
+ pre_gen + "/reference_file",
251
+ bpe_symbol=args.remove_bpe,
252
+ )
253
+ if args.prefix_len is not None:
254
+ bw_rescore_file = prefix_len_rescore_file
255
+ rerank_utils.write_reprocessed(
256
+ gen_output.source,
257
+ gen_output.hypo,
258
+ gen_output.target,
259
+ pre_gen + prefix_len_rescore_file + "." + args.source_lang,
260
+ pre_gen + prefix_len_rescore_file + "." + args.target_lang,
261
+ pre_gen + "/reference_file",
262
+ prefix_len=args.prefix_len,
263
+ bpe_symbol=args.remove_bpe,
264
+ )
265
+ elif args.target_prefix_frac is not None:
266
+ bw_rescore_file = target_prefix_frac_rescore_file
267
+ rerank_utils.write_reprocessed(
268
+ gen_output.source,
269
+ gen_output.hypo,
270
+ gen_output.target,
271
+ pre_gen
272
+ + target_prefix_frac_rescore_file
273
+ + "."
274
+ + args.source_lang,
275
+ pre_gen
276
+ + target_prefix_frac_rescore_file
277
+ + "."
278
+ + args.target_lang,
279
+ pre_gen + "/reference_file",
280
+ bpe_symbol=args.remove_bpe,
281
+ target_prefix_frac=args.target_prefix_frac,
282
+ )
283
+ else:
284
+ bw_rescore_file = rescore_file
285
+
286
+ if args.source_prefix_frac is not None:
287
+ fw_rescore_file = source_prefix_frac_rescore_file
288
+ rerank_utils.write_reprocessed(
289
+ gen_output.source,
290
+ gen_output.hypo,
291
+ gen_output.target,
292
+ pre_gen
293
+ + source_prefix_frac_rescore_file
294
+ + "."
295
+ + args.source_lang,
296
+ pre_gen
297
+ + source_prefix_frac_rescore_file
298
+ + "."
299
+ + args.target_lang,
300
+ pre_gen + "/reference_file",
301
+ bpe_symbol=args.remove_bpe,
302
+ source_prefix_frac=args.source_prefix_frac,
303
+ )
304
+ else:
305
+ fw_rescore_file = rescore_file
306
+
307
+ if args.right_to_left1 or args.right_to_left2:
308
+ rerank_utils.write_reprocessed(
309
+ gen_output.source,
310
+ gen_output.hypo,
311
+ gen_output.target,
312
+ pre_gen + "/right_to_left_rescore_data." + args.source_lang,
313
+ pre_gen + "/right_to_left_rescore_data." + args.target_lang,
314
+ pre_gen + "/right_to_left_reference_file",
315
+ right_to_left=True,
316
+ bpe_symbol=args.remove_bpe,
317
+ )
318
+
319
+ print("STEP 3: binarize the translations")
320
+ if (
321
+ not args.right_to_left1
322
+ or args.score_model2 is not None
323
+ and not args.right_to_left2
324
+ or not rerank1_is_gen
325
+ ):
326
+
327
+ if args.backwards1 or args.backwards2:
328
+ if args.backwards_score_dict_dir is not None:
329
+ bw_dict = args.backwards_score_dict_dir
330
+ else:
331
+ bw_dict = args.score_dict_dir
332
+ bw_preprocess_param = [
333
+ "--source-lang",
334
+ scorer1_src,
335
+ "--target-lang",
336
+ scorer1_tgt,
337
+ "--trainpref",
338
+ pre_gen + bw_rescore_file,
339
+ "--srcdict",
340
+ bw_dict + "/dict." + scorer1_src + ".txt",
341
+ "--tgtdict",
342
+ bw_dict + "/dict." + scorer1_tgt + ".txt",
343
+ "--destdir",
344
+ backwards_preprocessed_dir,
345
+ ]
346
+ preprocess_parser = options.get_preprocessing_parser()
347
+ input_args = preprocess_parser.parse_args(bw_preprocess_param)
348
+ preprocess.main(input_args)
349
+
350
+ preprocess_param = [
351
+ "--source-lang",
352
+ scorer1_src,
353
+ "--target-lang",
354
+ scorer1_tgt,
355
+ "--trainpref",
356
+ pre_gen + fw_rescore_file,
357
+ "--srcdict",
358
+ args.score_dict_dir + "/dict." + scorer1_src + ".txt",
359
+ "--tgtdict",
360
+ args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
361
+ "--destdir",
362
+ left_to_right_preprocessed_dir,
363
+ ]
364
+ preprocess_parser = options.get_preprocessing_parser()
365
+ input_args = preprocess_parser.parse_args(preprocess_param)
366
+ preprocess.main(input_args)
367
+
368
+ if args.right_to_left1 or args.right_to_left2:
369
+ preprocess_param = [
370
+ "--source-lang",
371
+ scorer1_src,
372
+ "--target-lang",
373
+ scorer1_tgt,
374
+ "--trainpref",
375
+ pre_gen + "/right_to_left_rescore_data",
376
+ "--srcdict",
377
+ args.score_dict_dir + "/dict." + scorer1_src + ".txt",
378
+ "--tgtdict",
379
+ args.score_dict_dir + "/dict." + scorer1_tgt + ".txt",
380
+ "--destdir",
381
+ right_to_left_preprocessed_dir,
382
+ ]
383
+ preprocess_parser = options.get_preprocessing_parser()
384
+ input_args = preprocess_parser.parse_args(preprocess_param)
385
+ preprocess.main(input_args)
386
+
387
+ return gen_output
388
+
389
+
390
+ def cli_main():
391
+ parser = rerank_options.get_reranking_parser()
392
+ args = options.parse_args_and_arch(parser)
393
+ gen_and_reprocess_nbest(args)
394
+
395
+
396
+ if __name__ == "__main__":
397
+ cli_main()
fairseq-0.10.2/examples/noisychannel/rerank_options.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from fairseq import options
7
+
8
+
9
+ def get_reranking_parser(default_task="translation"):
10
+ parser = options.get_parser("Generation and reranking", default_task)
11
+ add_reranking_args(parser)
12
+ return parser
13
+
14
+
15
+ def get_tuning_parser(default_task="translation"):
16
+ parser = options.get_parser("Reranking tuning", default_task)
17
+ add_reranking_args(parser)
18
+ add_tuning_args(parser)
19
+ return parser
20
+
21
+
22
+ def add_reranking_args(parser):
23
+ group = parser.add_argument_group("Reranking")
24
+ # fmt: off
25
+ group.add_argument('--score-model1', '-s1', type=str, metavar='FILE', required=True,
26
+ help='path to first model or ensemble of models for rescoring')
27
+ group.add_argument('--score-model2', '-s2', type=str, metavar='FILE', required=False,
28
+ help='path to second model or ensemble of models for rescoring')
29
+ group.add_argument('--num-rescore', '-n', type=int, metavar='N', default=10,
30
+ help='the number of candidate hypothesis to rescore')
31
+ group.add_argument('-bz', '--batch-size', type=int, metavar='N', default=128,
32
+ help='batch size for generating the nbest list')
33
+ group.add_argument('--gen-subset', default='test', metavar='SET', choices=['test', 'train', 'valid'],
34
+ help='data subset to generate (train, valid, test)')
35
+ group.add_argument('--gen-model', default=None, metavar='FILE',
36
+ help='the model to generate translations')
37
+ group.add_argument('-b1', '--backwards1', action='store_true',
38
+ help='whether or not the first model group is backwards')
39
+ group.add_argument('-b2', '--backwards2', action='store_true',
40
+ help='whether or not the second model group is backwards')
41
+ group.add_argument('-a', '--weight1', default=1, nargs='+', type=float,
42
+ help='the weight(s) of the first model')
43
+ group.add_argument('-b', '--weight2', default=1, nargs='+', type=float,
44
+ help='the weight(s) of the second model, or the gen model if using nbest from interactive.py')
45
+ group.add_argument('-c', '--weight3', default=1, nargs='+', type=float,
46
+ help='the weight(s) of the third model')
47
+
48
+ # lm arguments
49
+ group.add_argument('-lm', '--language-model', default=None, metavar='FILE',
50
+ help='language model for target language to rescore translations')
51
+ group.add_argument('--lm-dict', default=None, metavar='FILE',
52
+ help='the dict of the language model for the target language')
53
+ group.add_argument('--lm-name', default=None,
54
+ help='the name of the language model for the target language')
55
+ group.add_argument('--lm-bpe-code', default=None, metavar='FILE',
56
+ help='the bpe code for the language model for the target language')
57
+ group.add_argument('--data-dir-name', default=None,
58
+ help='name of data directory')
59
+ group.add_argument('--lenpen', default=1, nargs='+', type=float,
60
+ help='length penalty: <1.0 favors shorter, >1.0 favors longer sentences')
61
+ group.add_argument('--score-dict-dir', default=None,
62
+ help='the directory with dictionaries for the scoring models')
63
+ group.add_argument('--right-to-left1', action='store_true',
64
+ help='whether the first model group is a right to left model')
65
+ group.add_argument('--right-to-left2', action='store_true',
66
+ help='whether the second model group is a right to left model')
67
+ group.add_argument('--remove-bpe', '--post-process', default='@@ ',
68
+ help='the bpe symbol, used for the bitext and LM')
69
+ group.add_argument('--prefix-len', default=None, type=int,
70
+ help='the length of the target prefix to use in rescoring (in terms of words wo bpe)')
71
+ group.add_argument('--sampling', action='store_true',
72
+ help='use sampling instead of beam search for generating n best list')
73
+ group.add_argument('--diff-bpe', action='store_true',
74
+ help='bpe for rescoring and nbest list not the same')
75
+ group.add_argument('--rescore-bpe-code', default=None,
76
+ help='bpe code for rescoring models')
77
+ group.add_argument('--nbest-list', default=None,
78
+ help='use predefined nbest list in interactive.py format')
79
+ group.add_argument('--write-hypos', default=None,
80
+ help='filename prefix to write hypos to')
81
+ group.add_argument('--ref-translation', default=None,
82
+ help='reference translation to use with nbest list from interactive.py')
83
+ group.add_argument('--backwards-score-dict-dir', default=None,
84
+ help='the directory with dictionaries for the backwards model,'
85
+ 'if None then it is assumed the fw and backwards models share dictionaries')
86
+
87
+ # extra scaling args
88
+ group.add_argument('--gen-model-name', default=None,
89
+ help='the name of the models that generated the nbest list')
90
+ group.add_argument('--model1-name', default=None,
91
+ help='the name of the set for model1 group ')
92
+ group.add_argument('--model2-name', default=None,
93
+ help='the name of the set for model2 group')
94
+ group.add_argument('--shard-id', default=0, type=int,
95
+ help='the id of the shard to generate')
96
+ group.add_argument('--num-shards', default=1, type=int,
97
+ help='the number of shards to generate across')
98
+ group.add_argument('--all-shards', action='store_true',
99
+ help='use all shards')
100
+ group.add_argument('--target-prefix-frac', default=None, type=float,
101
+ help='the fraction of the target prefix to use in rescoring (in terms of words wo bpe)')
102
+ group.add_argument('--source-prefix-frac', default=None, type=float,
103
+ help='the fraction of the source prefix to use in rescoring (in terms of words wo bpe)')
104
+ group.add_argument('--normalize', action='store_true',
105
+ help='whether to normalize by src and target len')
106
+ # fmt: on
107
+ return group
108
+
109
+
110
+ def add_tuning_args(parser):
111
+ group = parser.add_argument_group("Tuning")
112
+
113
+ group.add_argument(
114
+ "--lower-bound",
115
+ default=[-0.7],
116
+ nargs="+",
117
+ type=float,
118
+ help="lower bound of search space",
119
+ )
120
+ group.add_argument(
121
+ "--upper-bound",
122
+ default=[3],
123
+ nargs="+",
124
+ type=float,
125
+ help="upper bound of search space",
126
+ )
127
+ group.add_argument(
128
+ "--tune-param",
129
+ default=["lenpen"],
130
+ nargs="+",
131
+ choices=["lenpen", "weight1", "weight2", "weight3"],
132
+ help="the parameter(s) to tune",
133
+ )
134
+ group.add_argument(
135
+ "--tune-subset",
136
+ default="valid",
137
+ choices=["valid", "test", "train"],
138
+ help="the subset to tune on ",
139
+ )
140
+ group.add_argument(
141
+ "--num-trials",
142
+ default=1000,
143
+ type=int,
144
+ help="number of trials to do for random search",
145
+ )
146
+ group.add_argument(
147
+ "--share-weights", action="store_true", help="share weight2 and weight 3"
148
+ )
149
+ return group
fairseq-0.10.2/examples/noisychannel/rerank_score_bw.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from contextlib import redirect_stdout
8
+
9
+ from fairseq import options
10
+ from fairseq_cli import generate
11
+
12
+ from . import rerank_options, rerank_utils
13
+
14
+
15
def score_bw(args):
    """STEP 4 of the noisy-channel reranking pipeline: rescore the n-best list.

    Runs fairseq generate with ``--score-reference`` for the first scoring
    model (and optionally a second one), capturing its stdout into the score
    files produced by :func:`rerank_utils.rescore_file_name`.  Scoring for a
    model is skipped when its score file already exists, or when that model is
    the generation model itself (its scores are already available).
    """
    # a backwards model scores target->source, so the language pair is swapped
    if args.backwards1:
        scorer1_src = args.target_lang
        scorer1_tgt = args.source_lang
    else:
        scorer1_src = args.source_lang
        scorer1_tgt = args.target_lang

    if args.score_model2 is not None:
        if args.backwards2:
            scorer2_src = args.target_lang
            scorer2_tgt = args.source_lang
        else:
            scorer2_src = args.source_lang
            scorer2_tgt = args.target_lang

    # when the rescoring model equals the generation model (and no source
    # prefix is in play), rescoring would just reproduce the generation scores
    rerank1_is_gen = (
        args.gen_model == args.score_model1 and args.source_prefix_frac is None
    )
    rerank2_is_gen = (
        args.gen_model == args.score_model2 and args.source_prefix_frac is None
    )

    (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    ) = rerank_utils.get_directories(
        args.data_dir_name,
        args.num_rescore,
        args.gen_subset,
        args.gen_model_name,
        args.shard_id,
        args.num_shards,
        args.sampling,
        args.prefix_len,
        args.target_prefix_frac,
        args.source_prefix_frac,
    )

    score1_file = rerank_utils.rescore_file_name(
        pre_gen,
        args.prefix_len,
        args.model1_name,
        target_prefix_frac=args.target_prefix_frac,
        source_prefix_frac=args.source_prefix_frac,
        backwards=args.backwards1,
    )

    # score2_file is only defined when a second model is configured; the
    # guard below short-circuits on args.score_model2 before reading it
    if args.score_model2 is not None:
        score2_file = rerank_utils.rescore_file_name(
            pre_gen,
            args.prefix_len,
            args.model2_name,
            target_prefix_frac=args.target_prefix_frac,
            source_prefix_frac=args.source_prefix_frac,
            backwards=args.backwards2,
        )

    # choose the preprocessed variant of the n-best list that matches the
    # direction of the first scoring model
    if args.right_to_left1:
        rerank_data1 = right_to_left_preprocessed_dir
    elif args.backwards1:
        rerank_data1 = backwards_preprocessed_dir
    else:
        rerank_data1 = left_to_right_preprocessed_dir

    # common generate flags: force-score the reference side of the
    # preprocessed n-best list (stored under the "train" split)
    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]
    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")

        model_param1 = [
            "--path",
            args.score_model1,
            "--source-lang",
            scorer1_src,
            "--target-lang",
            scorer1_tgt,
        ]
        gen_model1_param = [rerank_data1] + gen_param + model_param1

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model1_param)

        # generate writes its scores to stdout; capture them into the score file
        with open(score1_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)

    if (
        args.score_model2 is not None
        and not os.path.isfile(score2_file)
        and not rerank2_is_gen
    ):
        print("STEP 4: score the translations for model 2")

        if args.right_to_left2:
            rerank_data2 = right_to_left_preprocessed_dir
        elif args.backwards2:
            rerank_data2 = backwards_preprocessed_dir
        else:
            rerank_data2 = left_to_right_preprocessed_dir

        model_param2 = [
            "--path",
            args.score_model2,
            "--source-lang",
            scorer2_src,
            "--target-lang",
            scorer2_tgt,
        ]
        gen_model2_param = [rerank_data2] + gen_param + model_param2

        gen_parser = options.get_generation_parser()
        input_args = options.parse_args_and_arch(gen_parser, gen_model2_param)

        with open(score2_file, "w") as f:
            with redirect_stdout(f):
                generate.main(input_args)
134
+
135
+
136
def cli_main():
    """Command-line entry point: parse reranking args and run backward scoring."""
    args = options.parse_args_and_arch(rerank_options.get_reranking_parser())
    score_bw(args)
140
+
141
+
142
+ if __name__ == "__main__":
143
+ cli_main()
fairseq-0.10.2/examples/noisychannel/rerank_tune.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import random
8
+
9
+ import numpy as np
10
+ from fairseq import options
11
+
12
+ from . import rerank, rerank_options
13
+
14
+
15
def random_search(args):
    """Random-search the tuning parameters, then rerun reranking with the best.

    Draws ``args.num_trials`` uniform samples for each parameter listed in
    ``args.tune_param`` (within its lower/upper bound), keeps every other
    tuneable parameter fixed at its initial value, runs ``rerank.rerank`` on
    the tuning subset, and finally reruns reranking on the valid set and on
    ``args.gen_subset`` with the best hyperparameters found.
    """
    param_values = []
    tuneable_parameters = ["lenpen", "weight1", "weight2", "weight3"]
    initial_params = [args.lenpen, args.weight1, args.weight2, args.weight3]
    # normalize every initial value to a list so all params share one shape
    for i, elem in enumerate(initial_params):
        if type(elem) is not list:
            initial_params[i] = [elem]
        else:
            initial_params[i] = elem

    tune_parameters = args.tune_param.copy()
    # remove the tuned parameters from the "fixed" sets; what remains in
    # tuneable_parameters/initial_params is held constant across trials
    for i in range(len(args.tune_param)):
        assert args.upper_bound[i] >= args.lower_bound[i]
        index = tuneable_parameters.index(args.tune_param[i])
        del tuneable_parameters[index]
        del initial_params[index]

    # full column order: tuned parameters first, then the fixed ones
    tune_parameters += tuneable_parameters
    param_values += initial_params
    random.seed(args.seed)

    # one row per trial; one column per tuned parameter
    random_params = np.array(
        [
            [
                random.uniform(args.lower_bound[i], args.upper_bound[i])
                for i in range(len(args.tune_param))
            ]
            for k in range(args.num_trials)
        ]
    )
    # fixed parameters repeated for every trial
    set_params = np.array(
        [
            [initial_params[i][0] for i in range(len(tuneable_parameters))]
            for k in range(args.num_trials)
        ]
    )
    # columns now line up with tune_parameters
    random_params = np.concatenate((random_params, set_params), 1)

    rerank_args = vars(args).copy()
    if args.nbest_list:
        rerank_args["gen_subset"] = "test"
    else:
        rerank_args["gen_subset"] = args.tune_subset

    # hand rerank one list of candidate values per parameter
    for k in range(len(tune_parameters)):
        rerank_args[tune_parameters[k]] = list(random_params[:, k])

    if args.share_weights:
        # tie weight3 to the sampled weight2 column
        k = tune_parameters.index("weight2")
        rerank_args["weight3"] = list(random_params[:, k])

    rerank_args = argparse.Namespace(**rerank_args)
    best_lenpen, best_weight1, best_weight2, best_weight3, best_score = rerank.rerank(
        rerank_args
    )
    rerank_args = vars(args).copy()
    rerank_args["lenpen"] = [best_lenpen]
    rerank_args["weight1"] = [best_weight1]
    rerank_args["weight2"] = [best_weight2]
    rerank_args["weight3"] = [best_weight3]

    # write the hypothesis from the valid set from the best trial

    if args.gen_subset != "valid":
        rerank_args["gen_subset"] = "valid"
    rerank_args = argparse.Namespace(**rerank_args)
    rerank.rerank(rerank_args)

    # test with the best hyperparameters on gen subset
    rerank_args = vars(args).copy()
    rerank_args["gen_subset"] = args.gen_subset
    rerank_args["lenpen"] = [best_lenpen]
    rerank_args["weight1"] = [best_weight1]
    rerank_args["weight2"] = [best_weight2]
    rerank_args["weight3"] = [best_weight3]
    rerank_args = argparse.Namespace(**rerank_args)
    rerank.rerank(rerank_args)
92
+
93
+
94
def cli_main():
    """Command-line entry point: parse tuning args and launch the random search."""
    args = options.parse_args_and_arch(rerank_options.get_tuning_parser())
    random_search(args)
99
+
100
+
101
+ if __name__ == "__main__":
102
+ cli_main()
fairseq-0.10.2/examples/noisychannel/rerank_utils.py ADDED
@@ -0,0 +1,850 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import os
8
+ import re
9
+ import subprocess
10
+ from contextlib import redirect_stdout
11
+
12
+ from fairseq import options
13
+ from fairseq_cli import eval_lm, preprocess
14
+
15
+
16
def reprocess(fle):
    """Parse fairseq generate.py output into per-sentence dictionaries.

    Returns ``(source_dict, hypothesis_dict, score_dict, target_dict,
    pos_score_dict)`` keyed by the integer sentence id parsed from the
    ``S-/T-/H-/P-`` line prefixes.  There may be several hypotheses per
    source, so hypothesis, score and positional-score values are lists.
    """
    with open(fle, "r") as fh:
        contents = fh.read()

    # matches a line prefix such as "H-12" plus any following whitespace
    prefix_re = re.compile(r"[STHP][-]\d+\s*")
    # matches the hypothesis score (a float or -inf) after the prefix
    score_re = re.compile(r"(\s*[-]?\d+[.]?\d+\s*)|(\s*(-inf)\s*)")

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    for raw_line in contents.split("\n"):
        raw_line += "\n"
        match = re.search(prefix_re, raw_line)
        if match is None:
            continue
        assert len(match.group()) > 2, "prefix id not found"
        _, body_start = match.span()
        id_num = int(match.group()[2:])
        kind = match.group()[0]

        if kind == "H":
            rest = raw_line[body_start:]
            score_match = re.search(score_re, rest)
            assert (
                score_match is not None
            ), "regular expression failed to find the hypothesis scoring"
            _, text_start = score_match.span()
            hypothesis_dict.setdefault(id_num, []).append(rest[text_start:])
            score_dict.setdefault(id_num, []).append(float(score_match.group()))
        elif kind == "S":
            source_dict[id_num] = raw_line[body_start:]
        elif kind == "T":
            target_dict[id_num] = raw_line[body_start:]
        elif kind == "P":
            pos_scores = [float(tok) for tok in raw_line[body_start:].split()]
            pos_score_dict.setdefault(id_num, []).append(pos_scores)

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
73
+
74
+
75
def reprocess_nbest(fle):
    """Parse fairseq interactive.py n-best output into per-sentence dicts.

    Returns ``(source_dict, hypothesis_dict, score_dict, target_dict,
    pos_score_dict)`` keyed by a running sentence index that is advanced on
    every ``O`` (source) line.  Targets are not present in interactive
    output, so every target entry is the placeholder string "filler".
    """
    with open(fle, "r") as fh:
        contents = fh.read()

    source_dict = {}
    hypothesis_dict = {}
    score_dict = {}
    target_dict = {}
    pos_score_dict = {}

    score_re = re.compile(r"[-]?\d+[.]?\d+")
    sent_id = -1

    for raw_line in contents.split("\n"):
        raw_line += "\n"
        kind = raw_line[0]

        if kind == "H":
            score_match = re.search(score_re, raw_line)
            _, text_start = score_match.span()
            score_dict.setdefault(sent_id, []).append(float(score_match.group()))
            hypothesis_dict.setdefault(sent_id, []).append(
                raw_line[text_start:].strip("\t")
            )
        elif kind == "O":
            sent_id += 1
            source_dict[sent_id] = raw_line[2:]
            # we don't have the targets for interactive.py
            target_dict[sent_id] = "filler"
        elif kind == "P":
            pos_scores = [float(tok) for tok in raw_line.split()[1:]]
            pos_score_dict.setdefault(sent_id, []).append(pos_scores)

    assert source_dict.keys() == hypothesis_dict.keys()
    assert source_dict.keys() == pos_score_dict.keys()
    assert source_dict.keys() == score_dict.keys()

    return source_dict, hypothesis_dict, score_dict, target_dict, pos_score_dict
122
+
123
+
124
def write_reprocessed(
    sources,
    hypos,
    targets,
    source_outfile,
    hypo_outfile,
    target_outfile,
    right_to_left=False,
    prefix_len=None,
    bpe_symbol=None,
    target_prefix_frac=None,
    source_prefix_frac=None,
):

    """Write an n-best list as parallel source/hypo/target text files for rescoring.

    *sources*/*targets* map sentence id -> sentence; *hypos* maps sentence id
    -> list of hypotheses.  One line is written per hypothesis, with its
    source and target repeated alongside so the three files stay aligned.
    At most one of *prefix_len* (in words), *target_prefix_frac* or
    *source_prefix_frac* may be given; it truncates the hypothesis (or
    source) side before writing.  With *right_to_left* all token sequences
    are reversed (prefix truncation is not supported in that mode).
    """
    assert not (
        prefix_len is not None and target_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"
    assert not (
        prefix_len is not None and source_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"
    assert not (
        target_prefix_frac is not None and source_prefix_frac is not None
    ), "in writing reprocessed, only one type of prefix may be used"

    with open(source_outfile, "w") as source_file, open(
        hypo_outfile, "w"
    ) as hypo_file, open(target_outfile, "w") as target_file:

        assert len(sources) == len(hypos), "sources and hypos list length mismatch"
        if right_to_left:
            for i in range(len(sources)):
                for j in range(len(hypos[i])):
                    if prefix_len is None:
                        hypo_file.write(make_right_to_left(hypos[i][j]) + "\n")
                    else:
                        raise NotImplementedError()
                    source_file.write(make_right_to_left(sources[i]) + "\n")
                    target_file.write(make_right_to_left(targets[i]) + "\n")
        else:
            for i in sorted(sources.keys()):
                for j in range(len(hypos[i])):
                    if prefix_len is not None:
                        # truncate the hypothesis to a fixed number of words
                        shortened = (
                            get_prefix_no_bpe(hypos[i][j], bpe_symbol, prefix_len)
                            + "\n"
                        )
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif target_prefix_frac is not None:
                        # truncate the hypothesis to a fraction of its word count
                        num_words, shortened, num_bpe_tokens = calc_length_from_frac(
                            hypos[i][j], target_prefix_frac, bpe_symbol
                        )
                        shortened += "\n"
                        hypo_file.write(shortened)
                        source_file.write(sources[i])
                        target_file.write(targets[i])
                    elif source_prefix_frac is not None:
                        # truncate the source to a fraction of its word count
                        num_words, shortened, num_bpe_tokensn = calc_length_from_frac(
                            sources[i], source_prefix_frac, bpe_symbol
                        )
                        shortened += "\n"
                        hypo_file.write(hypos[i][j])
                        source_file.write(shortened)
                        target_file.write(targets[i])
                    else:
                        # no truncation: write everything through unchanged
                        hypo_file.write(hypos[i][j])
                        source_file.write(sources[i])
                        target_file.write(targets[i])
194
+
195
+
196
def calc_length_from_frac(bpe_sentence, prefix_frac, bpe_symbol):
    """Compute the prefix of *bpe_sentence* covering *prefix_frac* of its words.

    The fraction is applied to the word count after BPE removal (not to the
    BPE token count).  Returns ``(num_words, prefix_string, num_bpe_tokens)``.
    """
    word_count = len(remove_bpe(bpe_sentence, bpe_symbol).split())
    num_words = math.ceil(word_count * prefix_frac)
    prefix = get_prefix_no_bpe(bpe_sentence, bpe_symbol, num_words)
    return num_words, prefix, len(prefix.split())
205
+
206
+
207
def get_prefix(sentence, prefix_len):
    """Return the first *prefix_len* words of *sentence* (assumes no BPE).

    The trailing newline is stripped; if the sentence has *prefix_len* words
    or fewer, the whole (newline-stripped) sentence is returned.
    """
    words = sentence.strip("\n").split()
    if prefix_len >= len(words):
        return sentence.strip("\n")
    return " ".join(words[:prefix_len])
214
+
215
+
216
def get_prefix_no_bpe(sentence, bpe_symbol, prefix_len):
    """Prefix of *prefix_len* words; BPE-aware when *bpe_symbol* is given."""
    if bpe_symbol is None:
        return get_prefix(sentence, prefix_len)
    return " ".join(get_prefix_from_len(sentence.split(), bpe_symbol, prefix_len))
221
+
222
+
223
def get_prefix_from_len(sentence, bpe_symbol, prefix_len):
    """Take a word-count-based prefix from a list of BPE tokens.

    *sentence* is a list of BPE tokens; *prefix_len* counts whole words.
    Every continuation marker inside the taken slice means extra BPE tokens
    are still needed to complete those words, so recurse on the remainder.
    """
    marker = bpe_symbol.strip(" ")
    head = sentence[:prefix_len]
    continuation_count = sum(marker in token for token in head)
    if continuation_count == 0:
        return head
    return head + get_prefix_from_len(
        sentence[prefix_len:], bpe_symbol, continuation_count
    )
232
+
233
+
234
def get_num_bpe_tokens_from_len(sentence, bpe_symbol, prefix_len):
    """Return how many BPE tokens cover the first *prefix_len* words of *sentence*."""
    prefix = get_prefix_no_bpe(sentence, bpe_symbol, prefix_len)
    # sanity check: de-BPE'ing the prefix can never yield more than prefix_len words
    assert len(remove_bpe(prefix, bpe_symbol).split()) <= prefix_len
    return len(prefix.split(" "))
239
+
240
+
241
def make_right_to_left(line):
    """Return *line* with its whitespace-separated tokens in reverse order."""
    return " ".join(reversed(line.split()))
246
+
247
+
248
def remove_bpe(line, bpe_symbol):
    """Strip BPE continuation markers from *line*; the result ends with one newline."""
    flattened = line.replace("\n", "") + " "
    return flattened.replace(bpe_symbol, "").rstrip() + "\n"
252
+
253
+
254
def remove_bpe_dict(pred_dict, bpe_symbol):
    """Apply :func:`remove_bpe` to every value of *pred_dict*.

    Values may be single sentences or lists of sentences (n-best lists); the
    structure of each value is preserved in the returned dict.
    """
    new_dict = {}
    for key, value in pred_dict.items():
        # isinstance instead of `type(...) == list`: idiomatic, and also
        # accepts list subclasses
        if isinstance(value, list):
            new_dict[key] = [remove_bpe(elem, bpe_symbol) for elem in value]
        else:
            new_dict[key] = remove_bpe(value, bpe_symbol)
    return new_dict
263
+
264
+
265
def parse_bleu_scoring(line):
    """Extract the BLEU4 score from a line of fairseq scoring output."""
    match = re.search(r"(BLEU4 = )\d+[.]\d+", line)
    assert match is not None, line
    # drop the 8-character "BLEU4 = " label and keep the number
    return float(match.group()[8:])
270
+
271
+
272
def get_full_from_prefix(hypo_prefix, hypos):
    """Return the first hypothesis in *hypos* that begins with *hypo_prefix*.

    The prefix's trailing newline is ignored when matching.  Raises
    ``Exception`` if no hypothesis matches (same exception type as before,
    now with a diagnostic message).
    """
    # hoist the loop-invariant normalization out of the loop
    prefix = hypo_prefix.strip("\n")
    for hypo in hypos:
        if hypo.startswith(prefix):
            return hypo
    # no match found
    raise Exception("no hypothesis found for prefix: %r" % prefix)
281
+
282
+
283
def get_score(
    a,
    b,
    c,
    target_len,
    bitext_score1,
    bitext_score2=None,
    lm_score=None,
    lenpen=None,
    src_len=None,
    tgt_len=None,
    bitext1_backwards=False,
    bitext2_backwards=False,
    normalize=False,
):
    """Combine bitext and LM scores into one (optionally normalized,
    optionally length-penalized) rescoring value.

    ``a``, ``b`` and ``c`` weight the first bitext model, the second bitext
    model and the language model.  Under *normalize*, each bitext term is
    divided by the length of the side that model predicts (source for a
    backwards model, target otherwise) and the LM term by the source length.
    With *lenpen*, the total is divided by ``target_len ** lenpen``.
    """
    # length used to normalize each bitext model's score
    bitext1_norm = src_len if bitext1_backwards else tgt_len
    if bitext_score2 is not None:
        bitext2_norm = src_len if bitext2_backwards else tgt_len
    else:
        # no second model: neutral values make its term vanish
        bitext2_norm = 1
        bitext_score2 = 0

    if normalize:
        score = (
            a * bitext_score1 / bitext1_norm
            + b * bitext_score2 / bitext2_norm
            + c * lm_score / src_len
        )
    else:
        score = a * bitext_score1 + b * bitext_score2 + c * lm_score

    if lenpen is not None:
        score /= (target_len) ** float(lenpen)

    return score
323
+
324
+
325
class BitextOutput(object):
    """Parsed output of a rescoring run over an n-best list.

    Wraps :func:`reprocess` output and normalizes it so that, regardless of
    the rescoring model's direction (backwards / right-to-left), the
    ``rescore_*`` attributes are forward, left-to-right, de-BPE'd text keyed
    by sentence id, with one hypothesis per source sentence.
    """

    def __init__(
        self,
        output_file,
        backwards,
        right_to_left,
        bpe_symbol,
        prefix_len=None,
        target_prefix_frac=None,
        source_prefix_frac=None,
    ):
        """process output from rescoring"""
        source, hypo, score, target, pos_score = reprocess(output_file)
        # for a backwards model the "hypothesis" side is the source text,
        # so the relevant prefix fraction swaps accordingly
        if backwards:
            self.hypo_fracs = source_prefix_frac
        else:
            self.hypo_fracs = target_prefix_frac

        # remove length penalty so we can use raw scores
        score, num_bpe_tokens = get_score_from_pos(
            pos_score, prefix_len, hypo, bpe_symbol, self.hypo_fracs, backwards
        )
        source_lengths = {}
        target_lengths = {}

        assert hypo.keys() == source.keys(), "key mismatch"
        if backwards:
            # swap roles: the model's "hypothesis" is really the source side
            tmp = hypo
            hypo = source
            source = tmp
        for i in source:
            # since we are reranking, there should only be one hypo per source sentence
            if backwards:
                len_src = len(source[i][0].split())
                # record length without <eos>
                if len_src == num_bpe_tokens[i][0] - 1:
                    source_lengths[i] = num_bpe_tokens[i][0] - 1
                else:
                    source_lengths[i] = num_bpe_tokens[i][0]

                target_lengths[i] = len(hypo[i].split())

                source[i] = remove_bpe(source[i][0], bpe_symbol)
                target[i] = remove_bpe(target[i], bpe_symbol)
                hypo[i] = remove_bpe(hypo[i], bpe_symbol)

                # unwrap the single-element hypothesis lists
                score[i] = float(score[i][0])
                pos_score[i] = pos_score[i][0]

            else:
                len_tgt = len(hypo[i][0].split())
                # record length without <eos>
                if len_tgt == num_bpe_tokens[i][0] - 1:
                    target_lengths[i] = num_bpe_tokens[i][0] - 1
                else:
                    target_lengths[i] = num_bpe_tokens[i][0]

                source_lengths[i] = len(source[i].split())

                if right_to_left:
                    # restore left-to-right order before removing BPE
                    source[i] = remove_bpe(make_right_to_left(source[i]), bpe_symbol)
                    target[i] = remove_bpe(make_right_to_left(target[i]), bpe_symbol)
                    hypo[i] = remove_bpe(make_right_to_left(hypo[i][0]), bpe_symbol)
                    score[i] = float(score[i][0])
                    pos_score[i] = pos_score[i][0]
                else:
                    assert (
                        len(hypo[i]) == 1
                    ), "expected only one hypothesis per source sentence"
                    source[i] = remove_bpe(source[i], bpe_symbol)
                    target[i] = remove_bpe(target[i], bpe_symbol)
                    hypo[i] = remove_bpe(hypo[i][0], bpe_symbol)
                    score[i] = float(score[i][0])
                    pos_score[i] = pos_score[i][0]

        # normalized, de-BPE'd views keyed by sentence id
        self.rescore_source = source
        self.rescore_hypo = hypo
        self.rescore_score = score
        self.rescore_target = target
        self.rescore_pos_score = pos_score
        self.backwards = backwards
        self.right_to_left = right_to_left
        self.target_lengths = target_lengths
        self.source_lengths = source_lengths
409
+
410
+
411
class BitextOutputFromGen(object):
    """Parsed output of the original generation run (generate.py or an
    interactive.py n-best list).

    Keeps the raw per-sentence dictionaries and additionally builds
    ``rescore_*`` views with one flat index per hypothesis, so entries line
    up index-for-index with the rescoring models' :class:`BitextOutput`.
    """

    def __init__(
        self,
        predictions_bpe_file,
        bpe_symbol=None,
        nbest=False,
        prefix_len=None,
        target_prefix_frac=None,
    ):
        # nbest=True means interactive.py output; otherwise generate.py output
        if nbest:
            (
                pred_source,
                pred_hypo,
                pred_score,
                pred_target,
                pred_pos_score,
            ) = reprocess_nbest(predictions_bpe_file)
        else:
            pred_source, pred_hypo, pred_score, pred_target, pred_pos_score = reprocess(
                predictions_bpe_file
            )

        assert len(pred_source) == len(pred_hypo)
        assert len(pred_source) == len(pred_score)
        assert len(pred_source) == len(pred_target)
        assert len(pred_source) == len(pred_pos_score)

        # remove length penalty so we can use raw scores
        pred_score, num_bpe_tokens = get_score_from_pos(
            pred_pos_score, prefix_len, pred_hypo, bpe_symbol, target_prefix_frac, False
        )

        # raw dictionaries keyed by sentence id (hypo values are lists)
        self.source = pred_source
        self.target = pred_target
        self.score = pred_score
        self.pos_score = pred_pos_score
        self.hypo = pred_hypo
        self.target_lengths = {}
        self.source_lengths = {}

        self.no_bpe_source = remove_bpe_dict(pred_source.copy(), bpe_symbol)
        self.no_bpe_hypo = remove_bpe_dict(pred_hypo.copy(), bpe_symbol)
        self.no_bpe_target = remove_bpe_dict(pred_target.copy(), bpe_symbol)

        # indexes to match those from the rescoring models
        self.rescore_source = {}
        self.rescore_target = {}
        self.rescore_pos_score = {}
        self.rescore_hypo = {}
        self.rescore_score = {}
        self.num_hypos = {}
        self.backwards = False
        self.right_to_left = False

        index = 0

        # flatten (sentence id, hypothesis rank) into one running index
        for i in sorted(pred_source.keys()):
            for j in range(len(pred_hypo[i])):

                self.target_lengths[index] = len(self.hypo[i][j].split())
                self.source_lengths[index] = len(self.source[i].split())

                self.rescore_source[index] = self.no_bpe_source[i]
                self.rescore_target[index] = self.no_bpe_target[i]
                self.rescore_hypo[index] = self.no_bpe_hypo[i][j]
                self.rescore_score[index] = float(pred_score[i][j])
                self.rescore_pos_score[index] = pred_pos_score[i][j]
                self.num_hypos[index] = len(pred_hypo[i])
                index += 1
480
+
481
+
482
def get_score_from_pos(
    pos_score_dict, prefix_len, hypo_dict, bpe_symbol, hypo_frac, backwards
):
    """Reduce per-token position scores to one raw score per hypothesis.

    When a word prefix length or a prefix fraction is given, only the
    position scores covering that prefix are summed; otherwise the whole
    sequence is summed.  Returns ``(score_dict, num_bpe_tokens_dict)``, both
    keyed like *pos_score_dict* with one list entry per hypothesis.
    """
    assert prefix_len is None or hypo_frac is None
    score_dict = {}
    num_bpe_tokens_dict = {}
    for key, pos_scores_list in pos_score_dict.items():
        scores = []
        token_counts = []
        for i, pos_scores in enumerate(pos_scores_list):
            if prefix_len is not None and not backwards:
                # prefix given in words: translate to a BPE-token count first
                n_tokens = get_num_bpe_tokens_from_len(
                    hypo_dict[key][i], bpe_symbol, prefix_len
                )
            elif hypo_frac is not None:
                # prefix given as a fraction of the hypothesis word count
                _, _, n_tokens = calc_length_from_frac(
                    hypo_dict[key][i], hypo_frac, bpe_symbol
                )
            else:
                n_tokens = len(pos_scores)
            scores.append(sum(pos_scores[:n_tokens]))
            token_counts.append(n_tokens)
        score_dict[key] = scores
        num_bpe_tokens_dict[key] = token_counts
    return score_dict, num_bpe_tokens_dict
508
+
509
+
510
class LMOutput(object):
    """Language-model scores for an n-best list, parsed from an eval_lm dump."""

    def __init__(
        self,
        lm_score_file,
        lm_dict=None,
        prefix_len=None,
        bpe_symbol=None,
        target_prefix_frac=None,
    ):
        """Parse *lm_score_file* and expose its per-sentence LM statistics."""
        parsed = parse_lm(
            lm_score_file,
            prefix_len=prefix_len,
            bpe_symbol=bpe_symbol,
            target_prefix_frac=target_prefix_frac,
        )
        (
            lm_sentences,
            lm_sen_scores,
            lm_sen_pos_scores,
            lm_no_bpe_sentences,
            lm_bpe_tokens,
        ) = parsed

        self.sentences = lm_sentences
        self.score = lm_sen_scores
        self.pos_score = lm_sen_pos_scores
        self.lm_dict = lm_dict
        self.no_bpe_sentences = lm_no_bpe_sentences
        self.bpe_tokens = lm_bpe_tokens
538
+
539
+
540
def parse_lm(input_file, prefix_len=None, bpe_symbol=None, target_prefix_frac=None):
    """Parse the output of fairseq eval_lm run with --output-word-probs.

    The first 7 lines (header) and last 2 lines (summary) of the file are
    skipped.  Data lines look like ``<id> tok1 [s1] tok2 [s2] ... <eos> [sN]``.
    Returns ``(sentences, sen_scores, sen_pos_scores, no_bpe_sentences,
    num_bpe_tokens_dict)`` keyed by sentence id; ``no_bpe_sentences`` is only
    populated when *bpe_symbol* is given.
    """
    with open(input_file, "r") as fh:
        all_lines = fh.readlines()
    # drop the 7 header lines and the 2 trailing summary lines
    data_lines = all_lines[7:][:-2]

    sentences = {}
    sen_scores = {}
    sen_pos_scores = {}
    no_bpe_sentences = {}
    num_bpe_tokens_dict = {}

    for line in data_lines:
        tokens = line.split()
        if not tokens[0].isdigit():
            continue
        line_id = int(tokens[0])
        # words and bracketed scores alternate; strip the brackets
        scores = [float(tok[1:-1]) for tok in tokens[2::2]]
        # drop the trailing <eos> token when reconstructing the sentence
        sentences[line_id] = " ".join(tokens[1::2][:-1]) + "\n"
        if bpe_symbol is not None:
            # exclude <eos> symbol to match output from generate.py
            bpe_sen = " ".join(tokens[1::2][:-1]) + "\n"
            no_bpe_sentences[line_id] = remove_bpe(bpe_sen, bpe_symbol)

        if prefix_len is not None:
            num_bpe_tokens = get_num_bpe_tokens_from_len(
                bpe_sen, bpe_symbol, prefix_len
            )
            sen_scores[line_id] = sum(scores[:num_bpe_tokens])
            num_bpe_tokens_dict[line_id] = num_bpe_tokens
        elif target_prefix_frac is not None:
            _, _, target_prefix_len = calc_length_from_frac(
                bpe_sen, target_prefix_frac, bpe_symbol
            )
            sen_scores[line_id] = sum(scores[:target_prefix_len])
            num_bpe_tokens_dict[line_id] = target_prefix_len
        else:
            sen_scores[line_id] = sum(scores)
            num_bpe_tokens_dict[line_id] = len(scores)

        sen_pos_scores[line_id] = scores

    return sentences, sen_scores, sen_pos_scores, no_bpe_sentences, num_bpe_tokens_dict
583
+
584
+
585
def get_directories(
    data_dir_name,
    num_rescore,
    gen_subset,
    fw_name,
    shard_id,
    num_shards,
    sampling=False,
    prefix_len=None,
    target_prefix_frac=None,
    source_prefix_frac=None,
):
    """Build the directory layout for one n-best list and its preprocessed variants.

    Returns ``(pre_gen, left_to_right_preprocessed_dir,
    right_to_left_preprocessed_dir, backwards_preprocessed_dir,
    lm_preprocessed_dir)``.  All paths live under ``rerank_data`` next to
    this module.
    """
    # unique id for this n-best list configuration
    nbest_file_id = "nbest_{}_subset_{}_fw_name_{}_shard_{}_of_{}".format(
        num_rescore, gen_subset, fw_name, shard_id, num_shards
    )
    if sampling:
        nbest_file_id += "_sampling"

    # the directory containing all information for this nbest list
    pre_gen = (
        os.path.join(os.path.dirname(__file__))
        + "/rerank_data/"
        + data_dir_name
        + "/"
        + nbest_file_id
    )

    # preprocessed nbest list for left-to-right rescoring
    left_to_right_preprocessed_dir = pre_gen + "/left_to_right_preprocessed"
    if source_prefix_frac is not None:
        left_to_right_preprocessed_dir += "/prefix_frac" + str(source_prefix_frac)

    # preprocessed nbest list for right-to-left rescoring
    right_to_left_preprocessed_dir = pre_gen + "/right_to_left_preprocessed"

    # preprocessed nbest list for backwards-model rescoring; a prefix
    # setting gets its own subdirectory (fraction takes precedence)
    backwards_preprocessed_dir = pre_gen + "/backwards"
    if target_prefix_frac is not None:
        backwards_preprocessed_dir += "/prefix_frac" + str(target_prefix_frac)
    elif prefix_len is not None:
        backwards_preprocessed_dir += "/prefix_" + str(prefix_len)

    # preprocessed nbest list for rescoring with P(T)
    lm_preprocessed_dir = pre_gen + "/lm_preprocessed"

    return (
        pre_gen,
        left_to_right_preprocessed_dir,
        right_to_left_preprocessed_dir,
        backwards_preprocessed_dir,
        lm_preprocessed_dir,
    )
650
+
651
+
652
+ def lm_scoring(
653
+ preprocess_directory,
654
+ bpe_status,
655
+ gen_output,
656
+ pre_gen,
657
+ cur_lm_dict,
658
+ cur_lm_name,
659
+ cur_language_model,
660
+ cur_lm_bpe_code,
661
+ batch_size,
662
+ lm_score_file,
663
+ target_lang,
664
+ source_lang,
665
+ prefix_len=None,
666
+ ):
667
+ if prefix_len is not None:
668
+ assert (
669
+ bpe_status == "different"
670
+ ), "bpe status must be different to use prefix len"
671
+ if bpe_status == "no bpe":
672
+ # run lm on output without bpe
673
+ write_reprocessed(
674
+ gen_output.no_bpe_source,
675
+ gen_output.no_bpe_hypo,
676
+ gen_output.no_bpe_target,
677
+ pre_gen + "/rescore_data_no_bpe.de",
678
+ pre_gen + "/rescore_data_no_bpe.en",
679
+ pre_gen + "/reference_file_no_bpe",
680
+ )
681
+
682
+ preprocess_lm_param = [
683
+ "--only-source",
684
+ "--trainpref",
685
+ pre_gen + "/rescore_data_no_bpe." + target_lang,
686
+ "--srcdict",
687
+ cur_lm_dict,
688
+ "--destdir",
689
+ preprocess_directory,
690
+ ]
691
+ preprocess_parser = options.get_preprocessing_parser()
692
+ input_args = preprocess_parser.parse_args(preprocess_lm_param)
693
+ preprocess.main(input_args)
694
+
695
+ eval_lm_param = [
696
+ preprocess_directory,
697
+ "--path",
698
+ cur_language_model,
699
+ "--output-word-probs",
700
+ "--batch-size",
701
+ str(batch_size),
702
+ "--max-tokens",
703
+ "1024",
704
+ "--sample-break-mode",
705
+ "eos",
706
+ "--gen-subset",
707
+ "train",
708
+ ]
709
+
710
+ eval_lm_parser = options.get_eval_lm_parser()
711
+ input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
712
+
713
+ with open(lm_score_file, "w") as f:
714
+ with redirect_stdout(f):
715
+ eval_lm.main(input_args)
716
+
717
+ elif bpe_status == "shared":
718
+ preprocess_lm_param = [
719
+ "--only-source",
720
+ "--trainpref",
721
+ pre_gen + "/rescore_data." + target_lang,
722
+ "--srcdict",
723
+ cur_lm_dict,
724
+ "--destdir",
725
+ preprocess_directory,
726
+ ]
727
+ preprocess_parser = options.get_preprocessing_parser()
728
+ input_args = preprocess_parser.parse_args(preprocess_lm_param)
729
+ preprocess.main(input_args)
730
+
731
+ eval_lm_param = [
732
+ preprocess_directory,
733
+ "--path",
734
+ cur_language_model,
735
+ "--output-word-probs",
736
+ "--batch-size",
737
+ str(batch_size),
738
+ "--sample-break-mode",
739
+ "eos",
740
+ "--gen-subset",
741
+ "train",
742
+ ]
743
+
744
+ eval_lm_parser = options.get_eval_lm_parser()
745
+ input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
746
+
747
+ with open(lm_score_file, "w") as f:
748
+ with redirect_stdout(f):
749
+ eval_lm.main(input_args)
750
+
751
+ elif bpe_status == "different":
752
+ rescore_file = pre_gen + "/rescore_data_no_bpe"
753
+ rescore_bpe = pre_gen + "/rescore_data_new_bpe"
754
+
755
+ rescore_file += "."
756
+ rescore_bpe += "."
757
+
758
+ write_reprocessed(
759
+ gen_output.no_bpe_source,
760
+ gen_output.no_bpe_hypo,
761
+ gen_output.no_bpe_target,
762
+ rescore_file + source_lang,
763
+ rescore_file + target_lang,
764
+ pre_gen + "/reference_file_no_bpe",
765
+ bpe_symbol=None,
766
+ )
767
+
768
+ # apply LM bpe to nbest list
769
+ bpe_src_param = [
770
+ "-c",
771
+ cur_lm_bpe_code,
772
+ "--input",
773
+ rescore_file + target_lang,
774
+ "--output",
775
+ rescore_bpe + target_lang,
776
+ ]
777
+ subprocess.call(
778
+ [
779
+ "python",
780
+ os.path.join(
781
+ os.path.dirname(__file__), "subword-nmt/subword_nmt/apply_bpe.py"
782
+ ),
783
+ ]
784
+ + bpe_src_param,
785
+ shell=False,
786
+ )
787
+ # uncomment to use fastbpe instead of subword-nmt bpe
788
+ # bpe_src_param = [rescore_bpe+target_lang, rescore_file+target_lang, cur_lm_bpe_code]
789
+ # subprocess.call(["/private/home/edunov/fastBPE/fast", "applybpe"] + bpe_src_param, shell=False)
790
+
791
+ preprocess_dir = preprocess_directory
792
+
793
+ preprocess_lm_param = [
794
+ "--only-source",
795
+ "--trainpref",
796
+ rescore_bpe + target_lang,
797
+ "--srcdict",
798
+ cur_lm_dict,
799
+ "--destdir",
800
+ preprocess_dir,
801
+ ]
802
+ preprocess_parser = options.get_preprocessing_parser()
803
+ input_args = preprocess_parser.parse_args(preprocess_lm_param)
804
+ preprocess.main(input_args)
805
+
806
+ eval_lm_param = [
807
+ preprocess_dir,
808
+ "--path",
809
+ cur_language_model,
810
+ "--output-word-probs",
811
+ "--batch-size",
812
+ str(batch_size),
813
+ "--max-tokens",
814
+ "1024",
815
+ "--sample-break-mode",
816
+ "eos",
817
+ "--gen-subset",
818
+ "train",
819
+ ]
820
+
821
+ eval_lm_parser = options.get_eval_lm_parser()
822
+ input_args = options.parse_args_and_arch(eval_lm_parser, eval_lm_param)
823
+
824
+ with open(lm_score_file, "w") as f:
825
+ with redirect_stdout(f):
826
+ eval_lm.main(input_args)
827
+
828
+
829
+ def rescore_file_name(
830
+ nbest_dir,
831
+ prefix_len,
832
+ scorer_name,
833
+ lm_file=False,
834
+ target_prefix_frac=None,
835
+ source_prefix_frac=None,
836
+ backwards=None,
837
+ ):
838
+ if lm_file:
839
+ score_file = nbest_dir + "/lm_score_translations_model_" + scorer_name + ".txt"
840
+ else:
841
+ score_file = nbest_dir + "/" + scorer_name + "_score_translations.txt"
842
+ if backwards:
843
+ if prefix_len is not None:
844
+ score_file += "prefix_len" + str(prefix_len)
845
+ elif target_prefix_frac is not None:
846
+ score_file += "target_prefix_frac" + str(target_prefix_frac)
847
+ else:
848
+ if source_prefix_frac is not None:
849
+ score_file += "source_prefix_frac" + str(source_prefix_frac)
850
+ return score_file
fairseq-0.10.2/examples/paraphraser/README.md ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Paraphrasing with round-trip translation and mixture of experts
2
+
3
+ Machine translation models can be used to paraphrase text by translating it to
4
+ an intermediate language and back (round-trip translation).
5
+
6
+ This example shows how to paraphrase text by first passing it to an
7
+ English-French translation model, followed by a French-English [mixture of
8
+ experts translation model](/examples/translation_moe).
9
+
10
+ ##### 0. Setup
11
+
12
+ Clone fairseq from source and install necessary dependencies:
13
+ ```bash
14
+ git clone https://github.com/pytorch/fairseq.git
15
+ cd fairseq
16
+ pip install --editable .
17
+ pip install sacremoses sentencepiece
18
+ ```
19
+
20
+ ##### 1. Download models
21
+ ```bash
22
+ wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.en-fr.tar.gz
23
+ wget https://dl.fbaipublicfiles.com/fairseq/models/paraphraser.fr-en.hMoEup.tar.gz
24
+ tar -xzvf paraphraser.en-fr.tar.gz
25
+ tar -xzvf paraphraser.fr-en.hMoEup.tar.gz
26
+ ```
27
+
28
+ ##### 2. Paraphrase
29
+ ```bash
30
+ python examples/paraphraser/paraphrase.py \
31
+ --en2fr paraphraser.en-fr \
32
+ --fr2en paraphraser.fr-en.hMoEup
33
+ # Example input:
34
+ # The new date for the Games, postponed for a year in response to the coronavirus pandemic, gives athletes time to recalibrate their training schedules.
35
+ # Example outputs:
36
+ # Delayed one year in response to the coronavirus pandemic, the new date of the Games gives athletes time to rebalance their training schedule.
37
+ # The new date of the Games, which was rescheduled one year in response to the coronavirus (CV) pandemic, gives athletes time to rebalance their training schedule.
38
+ # The new date of the Games, postponed one year in response to the coronavirus pandemic, provides athletes with time to rebalance their training schedule.
39
+ # The Games' new date, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
40
+ # The new Games date, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
41
+ # The new date of the Games, which was postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their training schedule.
42
+ # The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to rebalance their training schedule.
43
+ # The new date of the Games, postponed one year in response to the coronavirus pandemic, gives athletes time to re-balance their training schedule.
44
+ # The new date of the Games, postponed one year in response to the coronavirus pandemic, gives the athletes time to rebalance their schedule of training.
45
+ # The new date of the Games, postponed one year in response to the pandemic of coronavirus, gives the athletes time to rebalance their training schedule.
46
+ ```
fairseq-0.10.2/examples/paraphraser/paraphrase.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3 -u
2
+
3
+ import argparse
4
+ import fileinput
5
+ import logging
6
+ import os
7
+ import sys
8
+
9
+ from fairseq.models.transformer import TransformerModel
10
+
11
+
12
+ logging.getLogger().setLevel(logging.INFO)
13
+
14
+
15
+ def main():
16
+ parser = argparse.ArgumentParser(description="")
17
+ parser.add_argument("--en2fr", required=True, help="path to en2fr model")
18
+ parser.add_argument(
19
+ "--fr2en", required=True, help="path to fr2en mixture of experts model"
20
+ )
21
+ parser.add_argument(
22
+ "--user-dir", help="path to fairseq examples/translation_moe/src directory"
23
+ )
24
+ parser.add_argument(
25
+ "--num-experts",
26
+ type=int,
27
+ default=10,
28
+ help="(keep at 10 unless using a different model)",
29
+ )
30
+ parser.add_argument(
31
+ "files",
32
+ nargs="*",
33
+ default=["-"],
34
+ help='input files to paraphrase; "-" for stdin',
35
+ )
36
+ args = parser.parse_args()
37
+
38
+ if args.user_dir is None:
39
+ args.user_dir = os.path.join(
40
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))), # examples/
41
+ "translation_moe",
42
+ "src",
43
+ )
44
+ if os.path.exists(args.user_dir):
45
+ logging.info("found user_dir:" + args.user_dir)
46
+ else:
47
+ raise RuntimeError(
48
+ "cannot find fairseq examples/translation_moe/src "
49
+ "(tried looking here: {})".format(args.user_dir)
50
+ )
51
+
52
+ logging.info("loading en2fr model from:" + args.en2fr)
53
+ en2fr = TransformerModel.from_pretrained(
54
+ model_name_or_path=args.en2fr,
55
+ tokenizer="moses",
56
+ bpe="sentencepiece",
57
+ ).eval()
58
+
59
+ logging.info("loading fr2en model from:" + args.fr2en)
60
+ fr2en = TransformerModel.from_pretrained(
61
+ model_name_or_path=args.fr2en,
62
+ tokenizer="moses",
63
+ bpe="sentencepiece",
64
+ user_dir=args.user_dir,
65
+ task="translation_moe",
66
+ ).eval()
67
+
68
+ def gen_paraphrases(en):
69
+ fr = en2fr.translate(en)
70
+ return [
71
+ fr2en.translate(fr, inference_step_args={"expert": i})
72
+ for i in range(args.num_experts)
73
+ ]
74
+
75
+ logging.info("Type the input sentence and press return:")
76
+ for line in fileinput.input(args.files):
77
+ line = line.strip()
78
+ if len(line) == 0:
79
+ continue
80
+ for paraphrase in gen_paraphrases(line):
81
+ print(paraphrase)
82
+
83
+
84
+ if __name__ == "__main__":
85
+ main()
fairseq-0.10.2/examples/simultaneous_translation/README.md ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Simultaneous Machine Translation
2
+
3
+ This directory contains the code for the paper [Monotonic Multihead Attention](https://openreview.net/forum?id=Hyg96gBKPS)
4
+
5
+ ## Prepare Data
6
+
7
+ [Please follow the instructions to download and preprocess the WMT'15 En-De dataset.](https://github.com/pytorch/fairseq/tree/simulastsharedtask/examples/translation#prepare-wmt14en2desh)
8
+
9
+ ## Training
10
+
11
+ - MMA-IL
12
+
13
+ ```shell
14
+ fairseq-train \
15
+ data-bin/wmt15_en_de_32k \
16
+ --simul-type infinite_lookback \
17
+ --user-dir $FAIRSEQ/example/simultaneous_translation \
18
+ --mass-preservation \
19
+ --criterion latency_augmented_label_smoothed_cross_entropy \
20
+ --latency-weight-avg 0.1 \
21
+ --max-update 50000 \
22
+ --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
23
+ --optimizer adam --adam-betas '(0.9, 0.98)' \
24
+ --lr-scheduler 'inverse_sqrt' \
25
+ --warmup-init-lr 1e-7 --warmup-updates 4000 \
26
+ --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
27
+ --dropout 0.3 \
28
+ --label-smoothing 0.1\
29
+ --max-tokens 3584
30
+ ```
31
+
32
+ - MMA-H
33
+
34
+ ```shell
35
+ fairseq-train \
36
+ data-bin/wmt15_en_de_32k \
37
+ --simul-type hard_aligned \
38
+ --user-dir $FAIRSEQ/example/simultaneous_translation \
39
+ --mass-preservation \
40
+ --criterion latency_augmented_label_smoothed_cross_entropy \
41
+ --latency-weight-var 0.1 \
42
+ --max-update 50000 \
43
+ --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
44
+ --optimizer adam --adam-betas '(0.9, 0.98)' \
45
+ --lr-scheduler 'inverse_sqrt' \
46
+ --warmup-init-lr 1e-7 --warmup-updates 4000 \
47
+ --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
48
+ --dropout 0.3 \
49
+ --label-smoothing 0.1\
50
+ --max-tokens 3584
51
+ ```
52
+
53
+ - wait-k
54
+
55
+ ```shell
56
+ fairseq-train \
57
+ data-bin/wmt15_en_de_32k \
58
+ --simul-type wait-k \
59
+ --waitk-lagging 3 \
60
+ --user-dir $FAIRSEQ/example/simultaneous_translation \
61
+ --mass-preservation \
62
+ --criterion latency_augmented_label_smoothed_cross_entropy \
63
+ --max-update 50000 \
64
+ --arch transformer_monotonic_iwslt_de_en save_dir_key=lambda \
65
+ --optimizer adam --adam-betas '(0.9, 0.98)' \
66
+ --lr-scheduler 'inverse_sqrt' \
67
+ --warmup-init-lr 1e-7 --warmup-updates 4000 \
68
+ --lr 5e-4 --min-lr 1e-9 --clip-norm 0.0 --weight-decay 0.0001\
69
+ --dropout 0.3 \
70
+ --label-smoothing 0.1\
71
+ --max-tokens 3584
72
+ ```
73
+
74
+
75
+ ## Evaluation
76
+
77
+ More details on evaluation can be found [here](https://github.com/pytorch/fairseq/blob/simulastsharedtask/examples/simultaneous_translation/docs/evaluation.md)
78
+
79
+ ### Start the server
80
+
81
+ ```shell
82
+ python ./eval/server.py \
83
+ --src-file $SRC_FILE \
84
+ --ref-file $TGT_FILE
85
+ ```
86
+
87
+ ### Run the client
88
+
89
+ ```shell
90
+ python ./evaluate.py \
91
+ --data-bin data-bin/wmt15_en_de_32k \
92
+ --model-path ./checkpoints/checkpoint_best.pt
93
+ --scores --output $RESULT_DIR
94
+ ```
95
+
96
+ ### Run evaluation locally without server
97
+
98
+ ```shell
99
+ python ./eval/evaluate.py
100
+ --local \
101
+ --src-file $SRC_FILE \
102
+ --tgt-file $TGT_FILE \
103
+ --data-bin data-bin/wmt15_en_de_32k \
104
+ --model-path ./checkpoints/checkpoint_best.pt \
105
+ --scores --output $RESULT_DIR
106
+ ```
fairseq-0.10.2/examples/simultaneous_translation/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from . import criterions, eval, models # noqa
fairseq-0.10.2/examples/simultaneous_translation/docs/baseline.md ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # **Baseline Simultaneous Translation**
2
+ ---
3
+
4
+ These are instructions for training and evaluating a *wait-k* simultaneous LSTM model on the MUST-C English-German dataset.
5
+
6
+ [STACL: Simultaneous Translation with Implicit Anticipation and Controllable Latency using Prefix-to-Prefix Framework](https://www.aclweb.org/anthology/P19-1289/)
7
+
8
+
9
+ ## **Requirements**
10
+ Install fairseq (make sure to use the correct branch):
11
+ ```
12
+ git clone --branch simulastsharedtask git@github.com:pytorch/fairseq.git
13
+ cd fairseq
14
+ pip install -e .
15
+ ```
16
+
17
+ Assuming that fairseq is installed in a directory called `FAIRSEQ`.
18
+
19
+ Install SentencePiece. One easy way is to use anaconda:
20
+
21
+ ```
22
+ conda install -c powerai sentencepiece
23
+ ```
24
+
25
+ Download the MuST-C data for English-German available at https://ict.fbk.eu/must-c/.
26
+ We will assume that the data is downloaded in a directory called `DATA_ROOT`.
27
+
28
+
29
+ ## **Text-to-text Model**
30
+ ---
31
+ ### Data Preparation
32
+ Train a SentencePiece model:
33
+ ```shell
34
+ for lang in en de; do
35
+ python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \
36
+ --data-path $DATA_ROOT/data \
37
+ --vocab-size 10000 \
38
+ --max-frame 3000 \
39
+ --model-type unigram \
40
+ --lang $lang \
41
+ --out-path .
42
+ done
+ ```
43
+
44
+ Process the data with the SentencePiece model:
45
+ ```shell
46
+ proc_dir=proc
47
+ mkdir -p $proc_dir
48
+ for split in train dev tst-COMMON tst-HE; do
49
+ for lang in en de; do
50
+ spm_encode \
51
+ --model unigram-$lang-10000-3000/spm.model \
52
+ < $DATA_ROOT/data/$split/txt/$split.$lang \
53
+ > $proc_dir/$split.spm.$lang
54
+ done
55
+ done
56
+ ```
57
+
58
+ Binarize the data:
59
+
60
+ ```shell
61
+ proc_dir=proc
62
+ fairseq-preprocess \
63
+ --source-lang en --target-lang de \
64
+ --trainpref $proc_dir/train.spm \
65
+ --validpref $proc_dir/dev.spm \
66
+ --testpref $proc_dir/tst-COMMON.spm \
67
+ --thresholdtgt 0 \
68
+ --thresholdsrc 0 \
69
+ --workers 20 \
70
+ --destdir ./data-bin/mustc_en_de \
71
+ ```
72
+
73
+ ### Training
74
+
75
+
76
+ ```shell
77
+ mkdir -p checkpoints
78
+ CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \
79
+ --save-dir checkpoints \
80
+ --arch berard_simul_text_iwslt \
81
+ --simul-type waitk \
82
+ --waitk-lagging 2 \
83
+ --optimizer adam \
84
+ --max-epoch 100 \
85
+ --lr 0.001 \
86
+ --clip-norm 5.0 \
87
+ --batch-size 128 \
88
+ --log-format json \
89
+ --log-interval 10 \
90
+ --criterion cross_entropy_acc \
91
+ --user-dir $FAIRSEQ/examples/simultaneous_translation
92
+ ```
93
+
94
+ ## **Speech-to-text Model**
95
+ ---
96
+ ### Data Preparation
97
+ First, segment wav files.
98
+ ```shell
99
+ python $FAIRSEQ/examples/simultaneous_translation/data/segment_wav.py \
100
+ --datapath $DATA_ROOT
101
+ ```
102
+ Similar to the text-to-text model, train a SentencePiece model, but only on German:
103
+ ```Shell
104
+ python $FAIRSEQ/examples/simultaneous_translation/data/train_spm.py \
105
+ --data-path $DATA_ROOT/data \
106
+ --vocab-size 10000 \
107
+ --max-frame 3000 \
108
+ --model-type unigram \
109
+ --lang de \
110
+ --out-path .
111
+ ```
112
+ ## Training
113
+ ```shell
114
+ mkdir -p checkpoints
115
+ CUDA_VISIBLE_DEVICES=1 python $FAIRSEQ/train.py data-bin/mustc_en_de \
116
+ --save-dir checkpoints \
117
+ --arch berard_simul_text_iwslt \
118
+ --waitk-lagging 2 \
119
+ --waitk-stride 10 \
120
+ --input-feat-per-channel 40 \
121
+ --encoder-hidden-size 512 \
122
+ --output-layer-dim 128 \
123
+ --decoder-num-layers 3 \
124
+ --task speech_translation \
125
+ --user-dir $FAIRSEQ/examples/simultaneous_translation
126
+ --optimizer adam \
127
+ --max-epoch 100 \
128
+ --lr 0.001 \
129
+ --clip-norm 5.0 \
130
+ --batch-size 128 \
131
+ --log-format json \
132
+ --log-interval 10 \
133
+ --criterion cross_entropy_acc \
134
+ --user-dir $FAIRSEQ/examples/simultaneous_translation
135
+ ```
136
+
137
+ ## Evaluation
138
+ ---
139
+ ### Evaluation Server
140
+ For text translation models, the server is set up as follows, given an input file and a reference file.
141
+
142
+ ``` shell
143
+ python ./eval/server.py \
144
+ --hostname localhost \
145
+ --port 12321 \
146
+ --src-file $DATA_ROOT/data/dev/txt/dev.en \
147
+ --ref-file $DATA_ROOT/data/dev/txt/dev.de
148
+ ```
149
+ For speech translation models, the input is the data directory.
150
+ ``` shell
151
+ python ./eval/server.py \
152
+ --hostname localhost \
153
+ --port 12321 \
154
+ --ref-file $DATA_ROOT \
155
+ --data-type speech
156
+ ```
157
+
158
+ ### Decode and Evaluate with Client
159
+ Once the server is set up, run client to evaluate translation quality and latency.
160
+ ```shell
161
+ # TEXT
162
+ python $fairseq_dir/examples/simultaneous_translation/evaluate.py \
163
+ data-bin/mustc_en_de \
164
+ --user-dir $FAIRSEQ/examples/simultaneous_translation \
165
+ --src-spm unigram-en-10000-3000/spm.model\
166
+ --tgt-spm unigram-de-10000-3000/spm.model\
167
+ -s en -t de \
168
+ --path checkpoints/checkpoint_best.pt
169
+
170
+ # SPEECH
171
+ python $fairseq_dir/examples/simultaneous_translation/evaluate.py \
172
+ data-bin/mustc_en_de \
173
+ --user-dir $FAIRSEQ/examples/simultaneous_translation \
174
+ --data-type speech \
175
+ --tgt-spm unigram-de-10000-3000/spm.model\
176
+ -s en -t de \
177
+ --path checkpoints/checkpoint_best.pt
178
+ ```
fairseq-0.10.2/examples/simultaneous_translation/docs/evaluation.md ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Introduction to evaluation interface
2
+ The simultaneous translation models from shared task participants are evaluated under a server-client protocol. The participants are requested to plug their own model API into the protocol, and to submit a Docker file.
3
+
4
+ ## Server-Client Protocol
5
+ A server-client protocol will be used in evaluation. For example, when a *wait-k* model (k=3) translates the English sentence "Alice and Bob are good friends" to the German sentence "Alice und Bob sind gute Freunde.", the evaluation process is shown in the following figure.
6
+
7
+ Every time the client needs to read a new state (word or speech utterance), a "GET" request is sent to the server. Whenever a new token is generated, a "SEND" request with the predicted word (untokenized) is sent to the server immediately. The server can hence calculate both the latency and the BLEU score of the sentence.
8
+
9
+ ### Server
10
+ The server code is provided and can be set up locally for development purposes. For example, to evaluate a simultaneous text translation test set:
11
+
12
+ ```shell
13
+
14
+ python fairseq/examples/simultaneous_translation/eval/server.py \
15
+ --hostname local_host \
16
+ --port 1234 \
17
+ --src-file SRC_FILE \
18
+ --ref-file REF_FILE \
19
+ --data-type text \
20
+ ```
21
+ The state that the server sends to the client has the following format:
22
+ ```json
23
+ {
24
+ 'sent_id': Int,
25
+ 'segment_id': Int,
26
+ 'segment': String
27
+ }
28
+ ```
29
+
30
+ ### Client
31
+ The client handles the evaluation process mentioned above. It should work out of the box as well. The client's protocol is shown in the following table.
32
+
33
+ |Action|Content|
34
+ |:---:|:---:|
35
+ |Request new word / utterance| ```{key: "Get", value: None}```|
36
+ |Predict word "W"| ```{key: "SEND", value: "W"}```|
37
+
38
+
39
+
40
+ The core of the client module is the agent, which needs to be adapted to each model accordingly. The abstract class of the agent is as follows; the evaluation process happens in the `decode()` function.
41
+ ```python
42
+ class Agent(object):
43
+ "an agent needs to follow this pattern"
44
+ def __init__(self, *args, **kwargs):
45
+ ...
46
+
47
+ def init_states(self):
48
+ # Initializing states
49
+ ...
50
+
51
+ def update_states(self, states, new_state):
52
+ # Update states with given new state from server
53
+ # TODO (describe the states)
54
+ ...
55
+
56
+ def finish_eval(self, states, new_state):
57
+ # Check if evaluation is finished
58
+ ...
59
+
60
+ def policy(self, state: list) -> dict:
61
+ # Provide a action given current states
62
+ # The action can only be either
63
+ # {key: "GET", value: NONE}
64
+ # or
65
+ # {key: "SEND", value: W}
66
+ ...
67
+
68
+ def reset(self):
69
+ # Reset agent
70
+ ...
71
+
72
+ def decode(self, session):
73
+
74
+ states = self.init_states()
75
+ self.reset()
76
+
77
+ # Evaluataion protocol happens here
78
+ while True:
79
+ # Get action from the current states according to self.policy()
80
+ action = self.policy(states)
81
+
82
+ if action['key'] == GET:
83
+ # Read a new state from server
84
+ new_state = session.get_src()
85
+ states = self.update_states(states, new_state)
86
+
87
+ if self.finish_eval(states, new_state):
88
+ # End of document
89
+ break
90
+
91
+ elif action['key'] == SEND:
92
+ # Send a new prediction to server
93
+ session.send_hypo(action['value'])
94
+
95
+ # Clean the history, wait for next sentence
96
+ if action['value'] == DEFAULT_EOS:
97
+ states = self.init_states()
98
+ self.reset()
99
+ else:
100
+ raise NotImplementedError
101
+
102
+
103
+ ```
104
+ Here is an implementation of an agent for the text [*wait-k* model](somelink). Note that tokenization is not considered.
105
+
106
+ ## Quality
107
+ The quality is measured by detokenized BLEU, so make sure that the predicted words sent to the server are detokenized. An implementation can be found [here](some link).
108
+
109
+ ## Latency
110
+ The latency metrics are
111
+ * Average Proportion
112
+ * Average Lagging
113
+ * Differentiable Average Lagging
114
+ Again, they will also be evaluated on detokenized text.
115
+
fairseq-0.10.2/examples/simultaneous_translation/eval/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
fairseq-0.10.2/examples/simultaneous_translation/eval/agents/word_splitter.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ class SubwordSplitter(object):
8
+ def process_line(self, string):
9
+ raise NotImplementedError
10
+
11
+ def split(self, string):
12
+ raise NotImplementedError
13
+
14
+
15
+ class NoneWordSplitter(object):
16
+ def __init__(self, model):
17
+ pass
18
+
19
+ def split(self, string):
20
+ return [string]
21
+
22
+ def process_line(self, string):
23
+ return [string]
24
+
25
+ def finished_word(self, string):
26
+ return True
27
+
28
+ def merge(self, list_of_string):
29
+ return "".join(list_of_string)
30
+
31
+ def last_full_word_step(self, tokens, step):
32
+ return len(tokens)
33
+
34
+ def end_idx_last_full_word(self, tokens):
35
+ return len(tokens)
36
+
37
+
38
+ class BPEWordSplitter(object):
39
+ # TODO: look back here
40
+ def __init__(self, model_path):
41
+ super().__init__()
42
+ from subword_nmt.apply_bpe import BPE
43
+
44
+ with open(model_path) as f:
45
+ self.model = BPE(f)
46
+
47
+ def split(self, string):
48
+ return self.model.process_line(string).split()
49
+
50
+ def end_idx_last_full_word(self, tokens):
51
+ # Begin of word indices
52
+ bow_indices = [0] + [i + 1 for i, t in enumerate(tokens[1:]) if t[-2:] != "@@"]
53
+
54
+ if len(bow_indices) < 2:
55
+ return 0
56
+ else:
57
+ return bow_indices[-1]
58
+
59
+ def merge(self, list_of_string):
60
+ return " ".join([item.replace("@@", "") for item in list_of_string])
61
+
62
+
63
+ class SentencePieceModelWordSplitter(object):
64
+ def __init__(self, model_path):
65
+ super().__init__()
66
+ import sentencepiece as spm
67
+
68
+ self.model = spm.SentencePieceProcessor()
69
+ self.model.Load(model_path)
70
+
71
+ def split(self, string):
72
+ return self.model.EncodeAsPieces(string)
73
+
74
+ def end_idx_last_full_word(self, tokens):
75
+ # Begin of word indices
76
+ bow_indices = [i for i, t in enumerate(tokens) if t[0] == "\u2581"]
77
+
78
+ if len(bow_indices) < 2:
79
+ return 0
80
+ else:
81
+ return bow_indices[-1]
82
+
83
+ def merge(self, list_of_string):
84
+ return self.model.DecodePieces(list_of_string)
85
+
86
+
87
+ SPLITTER_DICT = {
88
+ None: NoneWordSplitter,
89
+ "BPE": BPEWordSplitter,
90
+ "SentencePieceModel": SentencePieceModelWordSplitter,
91
+ }
fairseq-0.10.2/examples/simultaneous_translation/eval/client.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Optional
7
+
8
+ import requests
9
+ from scorers import build_scorer
10
+
11
+
12
+ class SimulSTEvaluationService(object):
13
+ DEFAULT_HOSTNAME = "localhost"
14
+ DEFAULT_PORT = 12321
15
+
16
+ def __init__(self, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT):
17
+ self.hostname = hostname
18
+ self.port = port
19
+ self.base_url = f"http://{self.hostname}:{self.port}"
20
+
21
+ def __enter__(self):
22
+ self.new_session()
23
+
24
+ def __exit__(self, exc_type, exc_val, exc_tb):
25
+ pass
26
+
27
+ def new_session(self):
28
+ # start eval session
29
+ url = f"{self.base_url}"
30
+
31
+ try:
32
+ _ = requests.post(url)
33
+ except Exception as e:
34
+ print(f"Failed to start an evaluation session: {e}")
35
+
36
+ print("Evaluation session started.")
37
+ return self
38
+
39
+ def get_scores(self):
40
+ # end eval session
41
+ url = f"{self.base_url}/result"
42
+ try:
43
+ r = requests.get(url)
44
+ print("Scores: {}".format(r.json()))
45
+ print("Evaluation session finished.")
46
+ except Exception as e:
47
+ print(f"Failed to end an evaluation session: {e}")
48
+
49
+ def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
50
+ url = f"{self.base_url}/src"
51
+ params = {"sent_id": sent_id}
52
+ if extra_params is not None:
53
+ for key in extra_params.keys():
54
+ params[key] = extra_params[key]
55
+ try:
56
+ r = requests.get(url, params=params)
57
+ except Exception as e:
58
+ print(f"Failed to request a source segment: {e}")
59
+ return r.json()
60
+
61
+ def send_hypo(self, sent_id: int, hypo: str) -> None:
62
+ url = f"{self.base_url}/hypo"
63
+ params = {"sent_id": sent_id}
64
+
65
+ try:
66
+ requests.put(url, params=params, data=hypo.encode("utf-8"))
67
+ except Exception as e:
68
+ print(f"Failed to send a translated segment: {e}")
69
+
70
+ def corpus_info(self):
71
+ url = f"{self.base_url}"
72
+ try:
73
+ r = requests.get(url)
74
+ except Exception as e:
75
+ print(f"Failed to request corpus information: {e}")
76
+
77
+ return r.json()
78
+
79
+
80
+ class SimulSTLocalEvaluationService(object):
81
+ def __init__(self, args):
82
+ self.scorer = build_scorer(args)
83
+
84
+ def get_scores(self):
85
+ return self.scorer.score()
86
+
87
+ def get_src(self, sent_id: int, extra_params: Optional[dict] = None) -> str:
88
+ if extra_params is not None:
89
+ segment_size = extra_params.get("segment_size", None)
90
+ else:
91
+ segment_size = None
92
+
93
+ return self.scorer.send_src(int(sent_id), segment_size)
94
+
95
+ def send_hypo(self, sent_id: int, hypo: str) -> None:
96
+ list_of_tokens = hypo.strip().split()
97
+ self.scorer.recv_hyp(sent_id, list_of_tokens)
98
+
99
+ def corpus_info(self):
100
+ return self.scorer.get_info()
fairseq-0.10.2/examples/simultaneous_translation/eval/eval_latency.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import json
8
+
9
+ import torch
10
+ from examples.simultaneous_translation.utils.latency import LatencyInference
11
+
12
+
13
+ LATENCY_METRICS = [
14
+ "differentiable_average_lagging",
15
+ "average_lagging",
16
+ "average_proportion",
17
+ ]
18
+
19
+
20
+ class LatencyScorer:
21
+ def __init__(self, start_from_zero=True):
22
+ self.recorder = []
23
+ self.scores = {}
24
+ self.scorer = LatencyInference()
25
+ self.start_from_zero = start_from_zero
26
+
27
+ def update_reorder(self, list_of_dict):
28
+ self.recorder = []
29
+ for info in list_of_dict:
30
+ delays = [int(x) - int(not self.start_from_zero) for x in info["delays"]]
31
+ delays = torch.LongTensor(delays).unsqueeze(0)
32
+ src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
33
+
34
+ self.recorder.append(self.scorer(delays, src_len))
35
+
36
+ def cal_latency(self):
37
+ self.scores = {}
38
+ for metric in LATENCY_METRICS:
39
+ self.scores[metric] = sum(
40
+ [x[metric][0, 0].item() for x in self.recorder]
41
+ ) / len(self.recorder)
42
+ return self.scores
43
+
44
+ @classmethod
45
+ def score(cls, list_of_dict, start_from_zero=True):
46
+ scorer_to_return = cls(start_from_zero)
47
+ scorer_to_return.update_reorder(list_of_dict)
48
+ scorer_to_return.cal_latency()
49
+ return scorer_to_return.scores
50
+
51
+
52
+ if __name__ == "__main__":
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument("--input", required=True)
55
+ parser.add_argument("--start-from-zero", action="store_true")
56
+ args = parser.parse_args()
57
+
58
+ scorer = LatencyInference()
59
+ recorder = []
60
+ with open(args.input, "r") as f:
61
+ for line in f:
62
+ info = json.loads(line)
63
+
64
+ delays = [int(x) - int(not args.start_from_zero) for x in info["delays"]]
65
+
66
+ delays = torch.LongTensor(delays).unsqueeze(0)
67
+
68
+ src_len = torch.LongTensor([info["src_len"]]).unsqueeze(0)
69
+
70
+ recorder.append(scorer(delays, src_len))
71
+
72
+ average_results = {}
73
+
74
+ for metric in LATENCY_METRICS:
75
+ average_results[metric] = sum([x[metric][0, 0].item() for x in recorder]) / len(
76
+ recorder
77
+ )
78
+ print(f"{metric}: {average_results[metric]}")
fairseq-0.10.2/examples/simultaneous_translation/eval/evaluate.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+
8
+ from agents import build_agent
9
+ from client import SimulSTEvaluationService, SimulSTLocalEvaluationService
10
+ from fairseq.registry import REGISTRIES
11
+
12
+
13
+ DEFAULT_HOSTNAME = "localhost"
14
+ DEFAULT_PORT = 12321
15
+
16
+
17
def get_args():
    """Parse evaluation CLI options, including registry-specific arguments.

    A first ``parse_known_args`` pass reads the core options so that the
    chosen registry entries (agent, scorer, ...) can append their own
    arguments; a second full parse then validates everything.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--hostname", type=str, default=DEFAULT_HOSTNAME, help="server hostname"
    )
    parser.add_argument(
        "--port", type=int, default=DEFAULT_PORT, help="server port number"
    )
    parser.add_argument("--agent-type", default="simul_trans_text", help="Agent type")
    parser.add_argument("--scorer-type", default="text", help="Scorer type")
    parser.add_argument(
        "--start-idx",
        type=int,
        default=0,
        help="Start index of the sentence to evaluate",
    )
    parser.add_argument(
        "--end-idx",
        type=int,
        default=float("inf"),
        help="End index of the sentence to evaluate",
    )
    parser.add_argument(
        "--scores", action="store_true", help="Request scores from server"
    )
    parser.add_argument("--reset-server", action="store_true", help="Reset the server")
    parser.add_argument(
        "--num-threads", type=int, default=10, help="Number of threads used by agent"
    )
    parser.add_argument(
        "--local", action="store_true", default=False, help="Local evaluation"
    )

    partial_args, _ = parser.parse_known_args()

    for registry_name, REGISTRY in REGISTRIES.items():
        chosen = getattr(partial_args, registry_name, None)
        if chosen is None:
            continue
        entry_cls = REGISTRY["registry"][chosen]
        if hasattr(entry_cls, "add_args"):
            entry_cls.add_args(parser)
    args = parser.parse_args()

    return args
62
+
63
+
64
if __name__ == "__main__":
    args = get_args()

    # Local evaluation runs the scorer in-process; otherwise talk to a
    # remote evaluation server over HTTP.
    if args.local:
        session = SimulSTLocalEvaluationService(args)
    else:
        session = SimulSTEvaluationService(args.hostname, args.port)

    if args.reset_server:
        session.new_session()

    if args.agent_type is not None:
        agent = build_agent(args)
        agent.decode(session, args.start_idx, args.end_idx, args.num_threads)

    if args.scores:
        # Fetch the scores once. The original called session.get_scores()
        # twice, discarding the first response (a wasted round-trip to the
        # server in the remote case).
        print(session.get_scores())
fairseq-0.10.2/examples/simultaneous_translation/eval/server.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+ import argparse
6
+ import json
7
+ import sys
8
+
9
+ from scorers import build_scorer
10
+ from tornado import ioloop, web
11
+
12
+
13
+ DEFAULT_HOSTNAME = "localhost"
14
+ DEFAULT_PORT = 12321
15
+
16
+
17
class ScorerHandler(web.RequestHandler):
    """Base request handler that shares one scorer instance across requests."""

    def initialize(self, scorer):
        # Tornado calls initialize() with the kwargs supplied in the URL
        # spec (see start_server), so every handler sees the same scorer.
        self.scorer = scorer
20
+
21
+
22
class EvalSessionHandler(ScorerHandler):
    """POST resets the evaluation session; GET returns session metadata."""

    def post(self):
        self.scorer.reset()

    def get(self):
        self.write(json.dumps(self.scorer.get_info()))
29
+
30
+
31
class ResultHandler(ScorerHandler):
    """GET returns the final evaluation scores as JSON."""

    def get(self):
        self.write(json.dumps(self.scorer.score()))
35
+
36
+
37
class SourceHandler(ScorerHandler):
    """GET streams (a segment of) the source for a given sentence id."""

    def get(self):
        sent_id = int(self.get_argument("sent_id"))

        # An absent or empty "segment_size" query argument means
        # "send the full remaining source".
        segment_size = None
        if "segment_size" in self.request.arguments:
            raw = self.get_argument("segment_size")
            if raw:
                segment_size = int(raw)

        self.write(json.dumps(self.scorer.send_src(int(sent_id), segment_size)))
49
+
50
+
51
class HypothesisHandler(ScorerHandler):
    """PUT receives whitespace-separated hypothesis tokens for a sentence."""

    def put(self):
        sent_id = int(self.get_argument("sent_id"))
        tokens = self.request.body.decode("utf-8").strip().split()
        self.scorer.recv_hyp(sent_id, tokens)
56
+
57
+
58
def add_args():
    """Parse server CLI options.

    Returns:
        argparse.Namespace with ``hostname``, ``port`` and ``debug``.
        Unknown options (e.g. scorer options consumed by build_scorer)
        are deliberately ignored via parse_known_args.
    """
    parser = argparse.ArgumentParser()
    # fmt: off
    parser.add_argument('--hostname', type=str, default=DEFAULT_HOSTNAME,
                        help='Server hostname')
    parser.add_argument('--port', type=int, default=DEFAULT_PORT,
                        help='Server port number')
    # The __main__ block reads args.debug, but no --debug option was ever
    # defined, so the server crashed with AttributeError on startup.
    parser.add_argument('--debug', action='store_true', default=False,
                        help='Run the tornado application in debug mode')

    args, _ = parser.parse_known_args()
    # fmt: on
    return args
69
+
70
+
71
def start_server(scorer, hostname=DEFAULT_HOSTNAME, port=DEFAULT_PORT, debug=False):
    """Start a tornado web application exposing ``scorer`` over HTTP.

    Blocks forever inside the tornado IO loop.
    NOTE(review): ``hostname`` is accepted but never used — app.listen binds
    on all interfaces; kept for interface compatibility.
    """
    routes = [
        (r"/result", ResultHandler, {"scorer": scorer}),
        (r"/src", SourceHandler, {"scorer": scorer}),
        (r"/hypo", HypothesisHandler, {"scorer": scorer}),
        (r"/", EvalSessionHandler, {"scorer": scorer}),
    ]
    app = web.Application(routes, debug=debug)
    # 1 GiB buffer: source audio segments can be large.
    app.listen(port, max_buffer_size=1024 ** 3)
    sys.stdout.write(f"Evaluation Server Started. Listening to port {port}\n")
    ioloop.IOLoop.current().start()
84
+
85
+
86
if __name__ == "__main__":
    args = add_args()
    scorer = build_scorer(args)
    # add_args() does not define a --debug option, so a plain args.debug
    # lookup raises AttributeError on startup; default to False instead.
    start_server(scorer, args.hostname, args.port, getattr(args, "debug", False))
fairseq-0.10.2/examples/simultaneous_translation/models/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
9
+
10
# Auto-import every sibling module so its @register_model decorators run.
# Files starting with "_" (this file, private helpers) are skipped.
for _file in os.listdir(os.path.dirname(__file__)):
    if not _file.endswith(".py") or _file.startswith("_"):
        continue
    module_name = _file[: _file.find(".py")]
    importlib.import_module(
        "examples.simultaneous_translation.models." + module_name
    )
fairseq-0.10.2/examples/simultaneous_translation/models/transformer_monotonic_attention.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from examples.simultaneous_translation.modules.monotonic_transformer_layer import (
10
+ TransformerMonotonicDecoderLayer,
11
+ TransformerMonotonicEncoderLayer,
12
+ )
13
+ from fairseq.models import register_model, register_model_architecture
14
+ from fairseq.models.transformer import (
15
+ TransformerDecoder,
16
+ TransformerEncoder,
17
+ TransformerModel,
18
+ base_architecture,
19
+ transformer_iwslt_de_en,
20
+ transformer_vaswani_wmt_en_de_big,
21
+ )
22
+
23
+
24
+ DEFAULT_MAX_SOURCE_POSITIONS = 1024
25
+ DEFAULT_MAX_TARGET_POSITIONS = 1024
26
+
27
+
28
@register_model("transformer_unidirectional")
class TransformerUnidirectionalModel(TransformerModel):
    """Transformer variant that only swaps in the monotonic (unidirectional)
    encoder; the decoder is the standard offline TransformerDecoder."""

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        # Everything else (decoder, embeddings, forward) is inherited
        # unchanged from TransformerModel.
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)
33
+
34
+
35
@register_model("transformer_monotonic")
class TransformerMonotonicModel(TransformerModel):
    """Transformer with monotonic-attention encoder and decoder for
    simultaneous translation: the decoder incrementally emits READ (0) /
    WRITE (1) actions via :meth:`decision_from_states`."""

    @classmethod
    def build_encoder(cls, args, src_dict, embed_tokens):
        return TransformerMonotonicEncoder(args, src_dict, embed_tokens)

    @classmethod
    def build_decoder(cls, args, tgt_dict, embed_tokens):
        return TransformerMonotonicDecoder(args, tgt_dict, embed_tokens)

    def _indices_from_states(self, states):
        # Build (src_indices, src_lengths, tgt_indices) from the agent's
        # ``states`` dict. Plain Python lists are converted to LongTensors
        # (on GPU when the model is on GPU); the target is prefixed with EOS,
        # which fairseq uses as the decoder's begin-of-sentence marker.
        if type(states["indices"]["src"]) == list:
            if next(self.parameters()).is_cuda:
                tensor = torch.cuda.LongTensor
            else:
                tensor = torch.LongTensor

            src_indices = tensor(
                [states["indices"]["src"][: 1 + states["steps"]["src"]]]
            )

            tgt_indices = tensor(
                [[self.decoder.dictionary.eos()] + states["indices"]["tgt"]]
            )
        else:
            src_indices = states["indices"]["src"][: 1 + states["steps"]["src"]]
            tgt_indices = states["indices"]["tgt"]

        # src_lengths is not tracked by the agent protocol, hence None.
        return src_indices, None, tgt_indices

    def predict_from_states(self, states):
        # Project the decoder features cached by decision_from_states() to
        # the vocabulary and greedily pick the token at the last position.
        decoder_states = self.decoder.output_layer(states["decoder_features"])
        lprobs = self.get_normalized_probs([decoder_states[:, -1:]], log_probs=True)

        index = lprobs.argmax(dim=-1)

        token = self.decoder.dictionary.string(index)

        # Return both the detokenized string and the raw vocabulary index.
        return token, index[0, 0].item()

    def decision_from_states(self, states):
        """
        Take the agent's states dictionary as input and decide whether to
        read another source token from the server (returns 0) or write a
        target token (returns 1). The decoder features are also computed and
        cached here so predict_from_states() can generate the target token
        without recomputing everything.
        """

        self.eval()

        # Nothing read yet: must READ first.
        if len(states["tokens"]["src"]) == 0:
            return 0

        src_indices, src_lengths, tgt_indices = self._indices_from_states(states)

        # Update encoder states if needed (i.e. new source tokens arrived
        # since the cached encoder output was computed).
        if (
            "encoder_states" not in states
            or states["encoder_states"][0].size(1) <= states["steps"]["src"]
        ):
            encoder_out_dict = self.encoder(src_indices, src_lengths)
            states["encoder_states"] = encoder_out_dict
        else:
            encoder_out_dict = states["encoder_states"]

        # online means we still need tokens to feed the model
        states["model_states"]["online"] = not (
            states["finish_read"]
            and len(states["tokens"]["src"]) == states["steps"]["src"]
        )

        states["model_states"]["steps"] = states["steps"]

        x, outputs = self.decoder.forward(
            prev_output_tokens=tgt_indices,
            encoder_out=encoder_out_dict,
            incremental_state=states["model_states"],
            features_only=True,
        )

        # Cache features for predict_from_states().
        states["decoder_features"] = x

        return outputs["action"]
118
+
119
+
120
class TransformerMonotonicEncoder(TransformerEncoder):
    """Transformer encoder built from TransformerMonotonicEncoderLayer."""

    def __init__(self, args, dictionary, embed_tokens):
        super().__init__(args, dictionary, embed_tokens)

        self.dictionary = dictionary
        # Discard the layers built by the parent constructor and replace
        # them with monotonic encoder layers.
        self.layers = nn.ModuleList(
            TransformerMonotonicEncoderLayer(args)
            for _ in range(args.encoder_layers)
        )
129
+
130
+
131
class TransformerMonotonicDecoder(TransformerDecoder):
    """
    Transformer decoder consisting of *args.decoder_layers* layers. Each layer
    is a :class:`TransformerMonotonicDecoderLayer`.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        dictionary (~fairseq.data.Dictionary): decoding dictionary
        embed_tokens (torch.nn.Embedding): output embedding
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(self, args, dictionary, embed_tokens, no_encoder_attn=False):
        # NOTE(review): no_encoder_attn is hard-coded to False in the super()
        # call but still forwarded to the monotonic layers below — confirm
        # this asymmetry is intentional.
        super().__init__(args, dictionary, embed_tokens, no_encoder_attn=False)

        self.dictionary = dictionary
        # Replace the parent's layers with monotonic decoder layers.
        self.layers = nn.ModuleList([])
        self.layers.extend(
            [
                TransformerMonotonicDecoderLayer(args, no_encoder_attn)
                for _ in range(args.decoder_layers)
            ]
        )

    def pre_attention(
        self, prev_output_tokens, encoder_out_dict, incremental_state=None
    ):
        """Embed target tokens (+ positions) and unpack the encoder output.

        Returns (x, encoder_out, encoder_padding_mask) with x already
        transposed to T x B x C for the attention layers.
        """
        positions = (
            self.embed_positions(
                prev_output_tokens,
                incremental_state=incremental_state,
            )
            if self.embed_positions is not None
            else None
        )

        # During incremental decoding only the newest token is embedded.
        if incremental_state is not None:
            prev_output_tokens = prev_output_tokens[:, -1:]
            if positions is not None:
                positions = positions[:, -1:]

        # embed tokens and positions
        x = self.embed_scale * self.embed_tokens(prev_output_tokens)

        if self.project_in_dim is not None:
            x = self.project_in_dim(x)

        if positions is not None:
            x += positions
        x = self.dropout_module(x)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = encoder_out_dict.encoder_out
        encoder_padding_mask = encoder_out_dict.encoder_padding_mask

        return x, encoder_out, encoder_padding_mask

    def post_attention(self, x):
        """Apply final layer norm / output projection and restore B x T x C."""
        if self.layer_norm:
            x = self.layer_norm(x)

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        if self.project_out_dim is not None:
            x = self.project_out_dim(x)

        return x

    def extract_features(
        self, prev_output_tokens, encoder_out, incremental_state=None, **unused
    ):
        """
        Similar to *forward* but only return features.

        During online incremental decoding this may return early with
        ``{"action": 0}`` (READ) when a layer decides more source is needed;
        otherwise it returns the features with ``{"action": 1}`` (WRITE).

        Returns:
            tuple:
                - the decoder's features of shape `(batch, tgt_len, embed_dim)`
                - a dictionary with any model-specific outputs
        """
        # incremental_state = None
        (x, encoder_outs, encoder_padding_mask) = self.pre_attention(
            prev_output_tokens, encoder_out, incremental_state
        )
        attn = None
        inner_states = [x]
        attn_list = []
        step_list = []

        for i, layer in enumerate(self.layers):

            x, attn, _ = layer(
                x=x,
                encoder_out=encoder_outs,
                encoder_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                self_attn_mask=self.buffered_future_mask(x)
                if incremental_state is None
                else None,
            )

            inner_states.append(x)
            attn_list.append(attn)

            if incremental_state is not None:
                curr_steps = layer.get_steps(incremental_state)
                step_list.append(curr_steps)

                if incremental_state.get("online", False):
                    # p_choose < 0.5 means the head wants to advance (read).
                    p_choose = (
                        attn["p_choose"].squeeze(0).squeeze(1).gather(1, curr_steps.t())
                    )

                    new_steps = curr_steps + (p_choose < 0.5).t().type_as(curr_steps)

                    if (new_steps >= incremental_state["steps"]["src"]).any():
                        # We need to prune the last self_attn saved_state
                        # if model decide not to read
                        # otherwise there will be duplicated saved_state
                        for j in range(i + 1):
                            self.layers[j].prune_incremental_state(incremental_state)

                        return x, {"action": 0}

        if incremental_state is not None and not incremental_state.get("online", False):
            # Here is for fast evaluation: track, per target position, the
            # furthest source step any layer has reached.
            fastest_step = (
                torch.max(torch.cat(step_list, dim=1), dim=1, keepdim=True)[0] + 1
            )

            if "fastest_step" in incremental_state:
                incremental_state["fastest_step"] = torch.cat(
                    [incremental_state["fastest_step"], fastest_step], dim=1
                )
            else:
                incremental_state["fastest_step"] = fastest_step

        x = self.post_attention(x)

        return x, {
            "action": 1,
            "attn_list": attn_list,
            "step_list": step_list,
            "encoder_out": encoder_out,
            "encoder_padding_mask": encoder_padding_mask,
        }

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder cached state (incl. fastest_step) for beam reordering."""
        super().reorder_incremental_state(incremental_state, new_order)
        if "fastest_step" in incremental_state:
            incremental_state["fastest_step"] = incremental_state[
                "fastest_step"
            ].index_select(0, new_order)
287
+
288
+
289
@register_model_architecture("transformer_monotonic", "transformer_monotonic")
def base_monotonic_rchitecture(args):
    # NOTE(review): function name is missing an "a" ("rchitecture"), but it
    # is called elsewhere in this file, so renaming would be a breaking change.
    base_architecture(args)
    args.encoder_unidirectional = getattr(args, "encoder_unidirectional", False)
293
+
294
+
295
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_iwslt_de_en"
)
def transformer_monotonic_iwslt_de_en(args):
    # IWSLT de-en preset plus the monotonic defaults.
    transformer_iwslt_de_en(args)
    base_monotonic_rchitecture(args)
301
+
302
+
303
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_de_big"
)
def transformer_monotonic_vaswani_wmt_en_de_big(args):
    # NOTE(review): unlike the iwslt preset above, this one does not apply
    # base_monotonic_rchitecture(args) — confirm whether that is intentional.
    transformer_vaswani_wmt_en_de_big(args)
309
+
310
+
311
@register_model_architecture(
    "transformer_monotonic", "transformer_monotonic_vaswani_wmt_en_fr_big"
)
def transformer_monotonic_vaswani_wmt_en_fr_big(args):
    """WMT en-fr big monotonic architecture.

    The original body called itself, so selecting this architecture hit
    Python's recursion limit immediately. Delegate to the en-de big preset
    defined above (fairseq's en-fr big transformer preset is the en-de big
    one with a different dropout, which remains overridable via CLI flags).
    """
    transformer_monotonic_vaswani_wmt_en_de_big(args)
316
+
317
+
318
@register_model_architecture(
    "transformer_unidirectional", "transformer_unidirectional_iwslt_de_en"
)
def transformer_unidirectional_iwslt_de_en(args):
    # Unidirectional encoder with the standard IWSLT de-en hyperparameters.
    transformer_iwslt_de_en(args)
fairseq-0.10.2/examples/simultaneous_translation/modules/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import importlib
7
+ import os
8
+
9
+ from fairseq import registry
10
+
11
+
12
# Set up the registry for monotonic attention variants, selected on the
# command line via --simul-type. The trailing "_" discards the unused
# fourth return value (the registry's dataclass registry).
(
    build_monotonic_attention,
    register_monotonic_attention,
    MONOTONIC_ATTENTION_REGISTRY,
    _,
) = registry.setup_registry("--simul-type")

# Auto-import every sibling module so its @register_monotonic_attention
# decorators run; files starting with "_" are skipped.
for file in os.listdir(os.path.dirname(__file__)):
    if file.endswith(".py") and not file.startswith("_"):
        model_name = file[: file.find(".py")]
        importlib.import_module(
            "examples.simultaneous_translation.modules." + model_name
        )
fairseq-0.10.2/examples/simultaneous_translation/modules/monotonic_multihead_attention.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from examples.simultaneous_translation.utils.functions import (
12
+ exclusive_cumprod,
13
+ lengths_to_mask,
14
+ )
15
+ from fairseq import utils
16
+ from fairseq.incremental_decoding_utils import with_incremental_state
17
+ from fairseq.modules import MultiheadAttention
18
+ from fairseq.utils import convert_padding_direction
19
+
20
+ from . import register_monotonic_attention
21
+
22
+
23
@with_incremental_state
class MonotonicAttention(nn.Module):
    """
    Abstract class of monotonic attentions.

    Designed to be mixed in with MultiheadAttention subclasses: __init__
    deliberately does not call nn.Module.__init__ and relies on attributes
    such as self.num_heads / self.out_proj being set by the co-parent
    (see MonotonicMultiheadAttentionHard below).
    """

    def __init__(self, args):
        # Numerical floor used when normalizing expected attention.
        self.eps = args.attention_eps
        # If True, residual probability mass is put on the last source token.
        self.mass_preservation = args.mass_preservation

        # Gaussian noise added to energies during training to encourage
        # discreteness of p_choose.
        self.noise_mean = args.noise_mean
        self.noise_var = args.noise_var

        self.energy_bias_init = args.energy_bias_init
        # Learnable scalar bias on the monotonic energy, or the constant 0
        # when --energy-bias is not set.
        self.energy_bias = (
            nn.Parameter(self.energy_bias_init * torch.ones([1]))
            if args.energy_bias is True
            else 0
        )

    @staticmethod
    def add_args(parser):
        """Add monotonic-attention options to an argparse parser."""
        # fmt: off
        parser.add_argument('--no-mass-preservation', action="store_false", dest="mass_preservation",
                            help='Do not stay on the last token when decoding')
        parser.add_argument('--mass-preservation', action="store_true", dest="mass_preservation",
                            help='Stay on the last token when decoding')
        parser.set_defaults(mass_preservation=True)

        parser.add_argument('--noise-var', type=float, default=1.0,
                            help='Variance of discretness noise')
        parser.add_argument('--noise-mean', type=float, default=0.0,
                            help='Mean of discretness noise')
        parser.add_argument('--energy-bias', action="store_true", default=False,
                            help='Bias for energy')
        parser.add_argument('--energy-bias-init', type=float, default=-2.0,
                            help='Initial value of the bias for energy')
        parser.add_argument('--attention-eps', type=float, default=1e-6,
                            help='Epsilon when calculating expected attention')
        # fmt: on

    def p_choose(self, *args):
        # Subclasses compute the stepwise read/write probability here.
        raise NotImplementedError

    def input_projections(self, *args):
        # Subclasses project query/key/value for the attention here.
        raise NotImplementedError

    def attn_energy(self, q_proj, k_proj, key_padding_mask=None):
        """
        Calculating monotonic energies

        ============================================================
        Expected input size
        q_proj: bsz * num_heads, tgt_len, self.head_dim
        k_proj: bsz * num_heads, src_len, self.head_dim
        key_padding_mask: bsz, src_len
        attn_mask: tgt_len, src_len
        """
        bsz, tgt_len, embed_dim = q_proj.size()
        bsz = bsz // self.num_heads
        src_len = k_proj.size(1)

        attn_energy = torch.bmm(q_proj, k_proj.transpose(1, 2)) + self.energy_bias

        attn_energy = attn_energy.view(bsz, self.num_heads, tgt_len, src_len)

        # Padded source positions get -inf energy so their p_choose is 0.
        if key_padding_mask is not None:
            attn_energy = attn_energy.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2).bool(),
                float("-inf"),
            )

        return attn_energy

    def expected_alignment_train(self, p_choose, key_padding_mask):
        """
        Calculating expected alignment for MMA
        Mask is not need because p_choose will be 0 if masked

        q_ij = (1 − p_{ij−1})q_{ij−1} + a+{i−1j}
        a_ij = p_ij q_ij

        parellel solution:
        ai = p_i * cumprod(1 − pi) * cumsum(a_i / cumprod(1 − pi))

        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        """

        # p_choose: bsz * num_heads, tgt_len, src_len
        bsz_num_heads, tgt_len, src_len = p_choose.size()

        # cumprod_1mp : bsz * num_heads, tgt_len, src_len
        cumprod_1mp = exclusive_cumprod(1 - p_choose, dim=2, eps=self.eps)
        # Clamped copy used only as the division denominator, for stability.
        cumprod_1mp_clamp = torch.clamp(cumprod_1mp, self.eps, 1.0)

        init_attention = p_choose.new_zeros([bsz_num_heads, 1, src_len])
        init_attention[:, :, 0] = 1.0

        previous_attn = [init_attention]

        for i in range(tgt_len):
            # p_choose: bsz * num_heads, tgt_len, src_len
            # cumprod_1mp_clamp : bsz * num_heads, tgt_len, src_len
            # previous_attn[i]: bsz * num_heads, 1, src_len
            # alpha_i: bsz * num_heads, src_len
            alpha_i = (
                p_choose[:, i]
                * cumprod_1mp[:, i]
                * torch.cumsum(previous_attn[i][:, 0] / cumprod_1mp_clamp[:, i], dim=1)
            ).clamp(0, 1.0)
            previous_attn.append(alpha_i.unsqueeze(1))

        # alpha: bsz * num_heads, tgt_len, src_len
        alpha = torch.cat(previous_attn[1:], dim=1)

        if self.mass_preservation:
            # Last token has the residual probabilities
            alpha[:, :, -1] = 1 - alpha[:, :, :-1].sum(dim=-1).clamp(0.0, 1.0)

        assert not torch.isnan(alpha).any(), "NaN detected in alpha."

        return alpha

    def expected_alignment_infer(self, p_choose, key_padding_mask, incremental_state):
        """
        Calculating monotonic alignment for MMA during inference time

        ============================================================
        Expected input size
        p_choose: bsz * num_heads, tgt_len, src_len
        key_padding_mask: bsz * src_len
        incremental_state: dict
        """
        # p_choose: bsz * self.num_heads, src_len
        bsz_num_heads, tgt_len, src_len = p_choose.size()
        # One token at a time
        assert tgt_len == 1
        p_choose = p_choose[:, 0, :]

        monotonic_cache = self._get_monotonic_buffer(incremental_state)

        # prev_monotonic_step: bsz, num_heads
        bsz = bsz_num_heads // self.num_heads
        prev_monotonic_step = monotonic_cache.get(
            "step", p_choose.new_zeros([bsz, self.num_heads]).long()
        )
        bsz, num_heads = prev_monotonic_step.size()
        assert num_heads == self.num_heads
        assert bsz * num_heads == bsz_num_heads

        # p_choose: bsz, num_heads, src_len
        p_choose = p_choose.view(bsz, num_heads, src_len)

        if key_padding_mask is not None:
            src_lengths = src_len - key_padding_mask.sum(dim=1, keepdim=True).long()
        else:
            src_lengths = prev_monotonic_step.new_ones(bsz, 1) * src_len

        # src_lengths: bsz, num_heads
        src_lengths = src_lengths.expand_as(prev_monotonic_step)
        # new_monotonic_step: bsz, num_heads
        new_monotonic_step = prev_monotonic_step

        step_offset = 0
        if key_padding_mask is not None:
            if key_padding_mask[:, 0].any():
                # left_pad_source = True:
                step_offset = key_padding_mask.sum(dim=-1, keepdim=True)

        max_steps = src_lengths - 1 if self.mass_preservation else src_lengths

        # finish_read: bsz, num_heads
        finish_read = new_monotonic_step.eq(max_steps)

        # Advance each head until it either stops (p_choose >= 0.5) or
        # exhausts the source.
        while finish_read.sum().item() < bsz * self.num_heads:
            # p_choose: bsz * self.num_heads, src_len
            # only choose the p at monotonic steps
            # p_choose_i: bsz , self.num_heads
            p_choose_i = (
                p_choose.gather(
                    2,
                    (step_offset + new_monotonic_step)
                    .unsqueeze(2)
                    .clamp(0, src_len - 1),
                )
            ).squeeze(2)

            action = (
                (p_choose_i < 0.5)
                .type_as(prev_monotonic_step)
                .masked_fill(finish_read, 0)
            )
            # 1 x bsz
            # sample actions on unfinished seq
            # 1 means stay, finish reading
            # 0 means leave, continue reading
            # dist = torch.distributions.bernoulli.Bernoulli(p_choose)
            # action = dist.sample().type_as(finish_read) * (1 - finish_read)

            new_monotonic_step += action

            finish_read = new_monotonic_step.eq(max_steps) | (action == 0)
            # finish_read = (~ (finish_read.sum(dim=1, keepdim=True) < self.num_heads / 2)) | finish_read

        monotonic_cache["step"] = new_monotonic_step

        # alpha: bsz * num_heads, 1, src_len
        # new_monotonic_step: bsz, num_heads
        # Hard alignment: all mass on the chosen step.
        alpha = p_choose.new_zeros([bsz * self.num_heads, src_len]).scatter(
            1,
            (step_offset + new_monotonic_step)
            .view(bsz * self.num_heads, 1)
            .clamp(0, src_len - 1),
            1,
        )

        if not self.mass_preservation:
            alpha = alpha.masked_fill(
                (new_monotonic_step == max_steps).view(bsz * self.num_heads, 1), 0
            )

        alpha = alpha.unsqueeze(1)

        self._set_monotonic_buffer(incremental_state, monotonic_cache)

        return alpha

    def v_proj_output(self, value):
        # Subclasses project the value tensor for the output attention here.
        raise NotImplementedError

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        incremental_state=None,
        *args,
        **kwargs,
    ):
        """Compute monotonic attention output and auxiliary distributions.

        Returns (attn, extra) where extra holds "alpha" (expected alignment),
        "beta" (expected attention) and "p_choose", each reshaped to
        bsz x num_heads x tgt_len x src_len.
        """

        tgt_len, bsz, embed_dim = query.size()
        src_len = value.size(0)

        # stepwise prob
        # p_choose: bsz * self.num_heads, tgt_len, src_len
        p_choose = self.p_choose(query, key, key_padding_mask)

        # expected alignment alpha
        # bsz * self.num_heads, tgt_len, src_len
        if incremental_state is not None:
            alpha = self.expected_alignment_infer(
                p_choose, key_padding_mask, incremental_state
            )
        else:
            alpha = self.expected_alignment_train(p_choose, key_padding_mask)

        # expected attention beta
        # bsz * self.num_heads, tgt_len, src_len
        beta = self.expected_attention(
            alpha, query, key, value, key_padding_mask, incremental_state
        )

        attn_weights = beta

        v_proj = self.v_proj_output(value)
        attn = torch.bmm(attn_weights.type_as(v_proj), v_proj)

        attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)

        attn = self.out_proj(attn)

        beta = beta.view(bsz, self.num_heads, tgt_len, src_len)
        alpha = alpha.view(bsz, self.num_heads, tgt_len, src_len)
        p_choose = p_choose.view(bsz, self.num_heads, tgt_len, src_len)

        return attn, {"alpha": alpha, "beta": beta, "p_choose": p_choose}

    def reorder_incremental_state(self, incremental_state, new_order):
        """Reorder buffered internal state (for incremental generation)."""
        super().reorder_incremental_state(incremental_state, new_order)
        input_buffer = self._get_monotonic_buffer(incremental_state)
        if input_buffer is not None:
            for k in input_buffer.keys():
                input_buffer[k] = input_buffer[k].index_select(0, new_order)
            self._set_monotonic_buffer(incremental_state, input_buffer)

    def _get_monotonic_buffer(self, incremental_state):
        # Returns the cached monotonic state dict, or {} on the first call.
        return (
            utils.get_incremental_state(
                self,
                incremental_state,
                "monotonic",
            )
            or {}
        )

    def _set_monotonic_buffer(self, incremental_state, buffer):
        utils.set_incremental_state(
            self,
            incremental_state,
            "monotonic",
            buffer,
        )

    def get_pointer(self, incremental_state):
        # Same lookup as _get_monotonic_buffer; kept as a public accessor.
        return (
            utils.get_incremental_state(
                self,
                incremental_state,
                "monotonic",
            )
            or {}
        )

    def get_fastest_pointer(self, incremental_state):
        # Furthest source step reached by any head.
        return self.get_pointer(incremental_state)["step"].max(0)[0]

    def set_pointer(self, incremental_state, p_choose):
        # Advance the per-head step counter by one wherever p_choose < 0.5.
        curr_pointer = self.get_pointer(incremental_state)
        if len(curr_pointer) == 0:
            buffer = torch.zeros_like(p_choose)
        else:
            buffer = self.get_pointer(incremental_state)["step"]

        buffer += (p_choose < 0.5).type_as(buffer)

        utils.set_incremental_state(
            self,
            incremental_state,
            "monotonic",
            {"step": buffer},
        )
358
+
359
+
360
@register_monotonic_attention("hard_aligned")
class MonotonicMultiheadAttentionHard(MonotonicAttention, MultiheadAttention):
    """Hard-aligned monotonic multihead attention (MMA-H).

    The soft attention distribution is the monotonic alignment itself
    (``expected_attention`` returns ``alpha`` unchanged).
    """

    def __init__(self, args):
        MultiheadAttention.__init__(
            self,
            embed_dim=args.decoder_embed_dim,
            num_heads=args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
        )

        MonotonicAttention.__init__(self, args)

        # Projection registries keyed by attention flavor; subclasses add a
        # "soft" entry for their lookback attention.
        self.k_in_proj = {"monotonic": self.k_proj}
        self.q_in_proj = {"monotonic": self.q_proj}
        self.v_in_proj = {"output": self.v_proj}

    def input_projections(self, query, key, value, name):
        """Project inputs and fold heads into the batch dimension.

        Expected input sizes (time-major):
            query: tgt_len, bsz, embed_dim
            key:   src_len, bsz, embed_dim
            value: src_len, bsz, embed_dim
        ``name`` selects the projection set (e.g. "monotonic", "soft",
        "output"). Any input may be None; its slot is returned as None.
        """

        def fold_heads(x):
            # (len, bsz, embed_dim) -> (bsz * num_heads, len, head_dim)
            bsz = x.size(1)
            return (
                x.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        q = k = v = None

        if query is not None:
            q = self.q_in_proj[name](query)
            q *= self.scaling  # pre-scale queries, as in standard MHA
            q = fold_heads(q)

        if key is not None:
            k = fold_heads(self.k_in_proj[name](key))

        if value is not None:
            v = fold_heads(self.v_in_proj[name](value))

        return q, k, v

    def p_choose(self, query, key, key_padding_mask=None):
        """Step-wise read/write probabilities (1 to read, 0 to write).

        Returns sigmoid(energy [+ noise]) reshaped to
        (bsz * num_heads, tgt_len, src_len). Gaussian noise is only added
        during training, to encourage near-discrete outputs.
        """
        # Monotonic-flavor projections; no value projection is needed here.
        q_proj, k_proj, _ = self.input_projections(query, key, None, "monotonic")

        # attn_energy: (bsz, num_heads, tgt_len, src_len)
        attn_energy = self.attn_energy(q_proj, k_proj, key_padding_mask)

        noise = 0
        if self.training:
            noise = (
                torch.normal(self.noise_mean, self.noise_var, attn_energy.size())
                .type_as(attn_energy)
                .to(attn_energy.device)
            )

        p_choose = torch.sigmoid(attn_energy + noise)
        _, _, tgt_len, src_len = p_choose.size()

        return p_choose.view(-1, tgt_len, src_len)

    def expected_attention(self, alpha, *args):
        """For MMA-H, beta = alpha."""
        return alpha

    def v_proj_output(self, value):
        """Project values with the "output" projection set."""
        _, _, v_proj = self.input_projections(None, None, value, "output")
        return v_proj
473
+
474
+
475
@register_monotonic_attention("infinite_lookback")
class MonotonicMultiheadAttentionInfiniteLookback(MonotonicMultiheadAttentionHard):
    """Monotonic attention with infinite lookback (MMA-IL): a separate
    "soft" attention attends over the whole prefix selected by the
    monotonic alignment alpha."""

    def __init__(self, args):
        super().__init__(args)
        self.init_soft_attention()

    def init_soft_attention(self):
        """Create the extra q/k projections used by the soft attention and
        register them under the "soft" key of the projection tables."""
        self.k_proj_soft = nn.Linear(self.kdim, self.embed_dim, bias=True)
        self.q_proj_soft = nn.Linear(self.embed_dim, self.embed_dim, bias=True)
        self.k_in_proj["soft"] = self.k_proj_soft
        self.q_in_proj["soft"] = self.q_proj_soft

        if self.qkv_same_dim:
            # Empirically observed the convergence to be much better with
            # the scaled initialization
            nn.init.xavier_uniform_(
                self.k_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
            nn.init.xavier_uniform_(
                self.q_in_proj["soft"].weight, gain=1 / math.sqrt(2)
            )
        else:
            nn.init.xavier_uniform_(self.k_in_proj["soft"].weight)
            nn.init.xavier_uniform_(self.q_in_proj["soft"].weight)

    def expected_attention(
        self, alpha, query, key, value, key_padding_mask, incremental_state
    ):
        """Compute the soft attention beta given the monotonic alignment.

        alpha: (bsz * num_heads, tgt_len, src_len) monotonic alignment.
        Returns beta of the same shape (after dropout).
        """
        # monotonic attention, we will calculate milk here
        bsz_x_num_heads, tgt_len, src_len = alpha.size()
        bsz = int(bsz_x_num_heads / self.num_heads)

        # Soft-attention energies from the dedicated "soft" projections.
        q, k, _ = self.input_projections(query, key, None, "soft")
        soft_energy = self.attn_energy(q, k, key_padding_mask)

        assert list(soft_energy.size()) == [bsz, self.num_heads, tgt_len, src_len]

        soft_energy = soft_energy.view(bsz * self.num_heads, tgt_len, src_len)

        if incremental_state is not None:
            # Inference: softmax over the prefix up to the current
            # monotonic step only (positions beyond it are masked to -inf).
            monotonic_cache = self._get_monotonic_buffer(incremental_state)
            monotonic_step = monotonic_cache["step"] + 1
            step_offset = 0
            if key_padding_mask is not None:
                if key_padding_mask[:, 0].any():
                    # left_pad_source = True:
                    step_offset = key_padding_mask.sum(dim=-1, keepdim=True)
            monotonic_step += step_offset
            mask = lengths_to_mask(
                monotonic_step.view(-1), soft_energy.size(2), 1
            ).unsqueeze(1)

            soft_energy = soft_energy.masked_fill(~mask.bool(), float("-inf"))
            # Subtract the row max for numerical stability before exp().
            soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
            exp_soft_energy = torch.exp(soft_energy)
            exp_soft_energy_sum = exp_soft_energy.sum(dim=2)
            beta = exp_soft_energy / exp_soft_energy_sum.unsqueeze(2)

        else:
            # Training: expected soft attention under alpha, computed with
            # a reversed cumulative sum over source positions.
            # bsz * num_heads, tgt_len, src_len
            soft_energy = soft_energy - soft_energy.max(dim=2, keepdim=True)[0]
            exp_soft_energy = torch.exp(soft_energy)
            exp_soft_energy_cumsum = torch.cumsum(exp_soft_energy, dim=2)

            if key_padding_mask is not None:
                if key_padding_mask.any():
                    # Replace cumsums at padded positions with eps to avoid
                    # dividing by zero below.
                    exp_soft_energy_cumsum = (
                        exp_soft_energy_cumsum.view(
                            -1, self.num_heads, tgt_len, src_len
                        )
                        .masked_fill(
                            key_padding_mask.unsqueeze(1).unsqueeze(1), self.eps
                        )
                        .view(-1, tgt_len, src_len)
                    )

            inner_items = alpha / exp_soft_energy_cumsum

            # Flipped cumsum implements the sum over j >= i in the MMA-IL
            # closed form.
            beta = exp_soft_energy * torch.cumsum(
                inner_items.flip(dims=[2]), dim=2
            ).flip(dims=[2])

        beta = self.dropout_module(beta)

        assert not torch.isnan(beta).any(), "NaN detected in beta."

        return beta
562
+
563
+
564
+ @register_monotonic_attention("waitk")
565
+ class MonotonicMultiheadAttentionWaitk(MonotonicMultiheadAttentionInfiniteLookback):
566
+ def __init__(self, args):
567
+ super().__init__(args)
568
+ self.q_in_proj["soft"] = self.q_in_proj["monotonic"]
569
+ self.k_in_proj["soft"] = self.k_in_proj["monotonic"]
570
+ self.waitk_lagging = args.waitk_lagging
571
+ assert (
572
+ self.waitk_lagging > 0
573
+ ), f"Lagging has to been larger than 0, get {self.waitk_lagging}."
574
+
575
+ @staticmethod
576
+ def add_args(parser):
577
+ super(
578
+ MonotonicMultiheadAttentionWaitk,
579
+ MonotonicMultiheadAttentionWaitk,
580
+ ).add_args(parser)
581
+
582
+ parser.add_argument(
583
+ "--waitk-lagging", type=int, required=True, help="Wait k lagging"
584
+ )
585
+
586
+ def p_choose(
587
+ self, query, key, key_padding_mask=None, attn_mask=None, incremental_state=None
588
+ ):
589
+ """
590
+ query: bsz, tgt_len
591
+ key: bsz, src_len
592
+ key_padding_mask: bsz, src_len
593
+ """
594
+ src_len, bsz, _ = key.size()
595
+ tgt_len, bsz, _ = query.size()
596
+ p_choose = query.new_ones(bsz, tgt_len, src_len)
597
+ p_choose = torch.tril(p_choose, diagonal=self.waitk_lagging - 1)
598
+ p_choose = torch.triu(p_choose, diagonal=self.waitk_lagging - 1)
599
+
600
+ if key_padding_mask is not None and key_padding_mask[:, 0].eq(1).any():
601
+ # Left pad source
602
+ # add -1 to the end
603
+ p_choose = p_choose.masked_fill(
604
+ key_padding_mask.float().flip(1).unsqueeze(1).bool(), -1
605
+ )
606
+ p_choose = convert_padding_direction(
607
+ p_choose.view(-1, src_len).long(), padding_idx=-1, right_to_left=True
608
+ )
609
+ p_choose = p_choose.view(bsz, tgt_len, src_len).type_as(query)
610
+ # remove -1
611
+ p_choose[p_choose.eq(-1)] = 0
612
+
613
+ # Extend to each head
614
+ p_choose = (
615
+ p_choose.contiguous()
616
+ .unsqueeze(1)
617
+ .expand(-1, self.num_heads, -1, -1)
618
+ .contiguous()
619
+ .view(-1, tgt_len, src_len)
620
+ )
621
+
622
+ return p_choose